From ba617db6a28f02481a8c6604878243af0393a85f Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Fri, 6 Jun 2025 00:13:50 -0400
Subject: [PATCH 01/54] wip

Signed-off-by: Kyle Sayers
---
 examples/transform/llama3_example.py          | 84 +++++++++++++++++++
 .../modifiers/transform/__init__.py           |  3 +
 .../modifiers/transform/template/quip.py      | 41 +++++++++
 .../modifiers/transform/template/spinquant.py | 65 ++++++++++++++
 .../modifiers/transform/transform.py          | 28 +++++++
 5 files changed, 221 insertions(+)
 create mode 100644 examples/transform/llama3_example.py
 create mode 100644 src/llmcompressor/modifiers/transform/__init__.py
 create mode 100644 src/llmcompressor/modifiers/transform/template/quip.py
 create mode 100644 src/llmcompressor/modifiers/transform/template/spinquant.py
 create mode 100644 src/llmcompressor/modifiers/transform/transform.py

diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py
new file mode 100644
index 000000000..0c976874d
--- /dev/null
+++ b/examples/transform/llama3_example.py
@@ -0,0 +1,84 @@
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor.modifiers.quantization import GPTQModifier
+from llmcompressor.modifiers.transform import TransformModifier
+from llmcompressor.transformers import oneshot
+
+# Select model and load it.
+MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
+
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID,
+    device_map="auto",
+    torch_dtype="auto",
+)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+
+# Select calibration dataset.
+DATASET_ID = "HuggingFaceH4/ultrachat_200k"
+DATASET_SPLIT = "train_sft"
+
+# Select number of samples. 512 samples is a good place to start.
+# Increasing the number of samples can improve accuracy.
+NUM_CALIBRATION_SAMPLES = 512
+MAX_SEQUENCE_LENGTH = 2048
+
+# Load dataset and preprocess.
+ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
+ds = ds.shuffle(seed=42)
+
+
+def preprocess(example):
+    return {
+        "text": tokenizer.apply_chat_template(
+            example["messages"],
+            tokenize=False,
+        )
+    }
+
+
+ds = ds.map(preprocess)
+
+
+# Tokenize inputs.
+def tokenize(sample):
+    return tokenizer(
+        sample["text"],
+        padding=False,
+        max_length=MAX_SEQUENCE_LENGTH,
+        truncation=True,
+        add_special_tokens=False,
+    )
+
+
+ds = ds.map(tokenize, remove_columns=ds.column_names)
+
+# Configure the quantization algorithm to run.
+# * quantize the weights to 4 bit with GPTQ with a group size 128
+recipe = [
+    TransformModifier(),
+    GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])
+]
+
+# Apply algorithms.
+oneshot(
+    model=model,
+    dataset=ds,
+    recipe=recipe,
+    max_seq_length=MAX_SEQUENCE_LENGTH,
+    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+)
+
+# Confirm generations of the quantized model look sane.
+print("\n\n")
+print("========== SAMPLE GENERATION ==============")
+input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+output = model.generate(input_ids, max_new_tokens=100)
+print(tokenizer.decode(output[0]))
+print("==========================================\n\n")
+
+# Save to disk compressed.
+SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128"
+model.save_pretrained(SAVE_DIR, save_compressed=True)
+tokenizer.save_pretrained(SAVE_DIR)
diff --git a/src/llmcompressor/modifiers/transform/__init__.py b/src/llmcompressor/modifiers/transform/__init__.py
new file mode 100644
index 000000000..85e8972b4
--- /dev/null
+++ b/src/llmcompressor/modifiers/transform/__init__.py
@@ -0,0 +1,3 @@
+# flake8: noqa
+
+from .transform import TransformModifier
\ No newline at end of file
diff --git a/src/llmcompressor/modifiers/transform/template/quip.py b/src/llmcompressor/modifiers/transform/template/quip.py
new file mode 100644
index 000000000..070fec03a
--- /dev/null
+++ b/src/llmcompressor/modifiers/transform/template/quip.py
@@ -0,0 +1,41 @@
+from compressed_tensors.transform import TransformArgs, TransformScheme, TransformConfig
+
+
+QUIP = TransformConfig(
+    config_groups={
+        "v": TransformScheme(
+            type="hadamard",
+            apply=[
+                TransformArgs(
+                    targets=["Linear"],
+                    location="input", # non-mergable
+                    ignore="lm_head",
+                ),
+                TransformArgs(
+                    targets=["Linear"],
+                    location="weight_input",
+                    inverse=True,
+                    ignore="lm_head",
+                ),
+            ],
+            randomize=True,
+        ),
+        "u": TransformScheme(
+            type="hadamard",
+            apply=[
+                TransformArgs(
+                    targets=["Linear"],
+                    location="weight_output",
+                    ignore="lm_head",
+                ),
+                TransformArgs(
+                    targets=["Linear"],
+                    location="output", # non-mergable
+                    inverse=True,
+                    ignore="lm_head"
+                ),
+            ],
+            randomize=True,
+        ),
+    }
+)
\ No newline at end of file
diff --git a/src/llmcompressor/modifiers/transform/template/spinquant.py b/src/llmcompressor/modifiers/transform/template/spinquant.py
new file mode 100644
index 000000000..b9d7c5844
--- /dev/null
+++ b/src/llmcompressor/modifiers/transform/template/spinquant.py
@@ -0,0 +1,65 @@
+from compressed_tensors.transform import TransformArgs, TransformScheme, TransformConfig
+
+
+LLAMA_SPINQUANT = TransformConfig(
+    transform_groups={
+        "R1": TransformScheme(
+            type="hadamard",
+            apply=[
+                TransformArgs(
+                    targets=["embed_tokens", "o_proj", "down_proj"],
+                    location="weight_output",
+                ),
+                TransformArgs(
+                    targets=[
+                        "q_proj",
+                        "k_proj",
+                        "v_proj",
+                        "up_proj",
+                        "gate_proj",
+                        "lm_head",
+                    ],
+                    location="weight_input",
+                    inverse=True,
+                ),
+            ],
+        ),
+        "R2": TransformScheme(
+            type="hadamard",
+            apply=[
+                TransformArgs(
+                    targets=["v_proj"],
+                    location="weight_output",
+                ),
+                TransformArgs(
+                    targets=["o_proj"], location="weight_input", inverse=True
+                ),
+            ],
+        ),
+        "R3": TransformScheme(
+            type="hadamard",
+            apply=[
+                TransformArgs(
+                    targets=["self_attn"],
+                    location="k_cache",
+                ),
+                TransformArgs(
+                    targets=["self_attn"],
+                    location="q_attn",
+                ),
+            ],
+        ),
+        "R4": TransformScheme(
+            type="hadamard",
+            apply=[
+                TransformArgs(
+                    targets=["down_proj"],
+                    location="input",
+                ),
+                TransformArgs(
+                    targets=["down_proj"], location="weight_input", inverse=True
+                ),
+            ],
+        ),
+    }
+)
\ No newline at end of file
diff --git a/src/llmcompressor/modifiers/transform/transform.py b/src/llmcompressor/modifiers/transform/transform.py
new file mode 100644
index 000000000..1700e12f1
--- /dev/null
+++ b/src/llmcompressor/modifiers/transform/transform.py
@@ -0,0 +1,28 @@
+from typing import Dict, Optional
+
+from llmcompressor.core import State
+from llmcompressor.modifiers import Modifier
+
+from compressed_tensors.transform import TransformConfig, TransformScheme, TransformFactory
+
+from .template.quip import QUIP
+
+class TransformModifier(Modifier):
+    preset_config: Optional[str] = None
+    config_groups: Optional[Dict[str, TransformScheme]] = None
+
+    # model validator to validate both preset and config gropus are not provided
+
+    def on_initialize(self, state: State, **kwargs):
+        if self.preset_config is not None:
+            # import config template and customize to model
+            pass
+
+
+        #config = TransformConfig(config_groups=self.config_groups)
+        config = QUIP
+
+        # TODO: use CT-provided apply_transform_config
+        for name, scheme in config.config_groups.items():
+            factory = TransformFactory.from_scheme(scheme, name=name)
+            factory.apply_to_model(state.model)
\ No newline at end of file
From 2f5b1c8a20ddffd9f83cf2984d36007bf7cdefe5 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Wed, 11 Jun 2025 23:15:46 -0400
Subject: [PATCH 02/54] use random-hadamard, add correctness tests

Signed-off-by: Kyle Sayers
---
 examples/transform/llama3_example.py               |  2 +-
 src/llmcompressor/modifiers/transform/__init__.py  |  2 +-
 .../modifiers/transform/template/quip.py           | 11 +++++------
 .../modifiers/transform/template/spinquant.py      |  5 ++---
 src/llmcompressor/modifiers/transform/transform.py | 14 ++++++--------
 5 files changed, 15 insertions(+), 19 deletions(-)

diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py
index 0c976874d..41bb4921c 100644
--- a/examples/transform/llama3_example.py
+++ b/examples/transform/llama3_example.py
@@ -58,7 +58,7 @@ def tokenize(sample):
 # * quantize the weights to 4 bit with GPTQ with a group size 128
 recipe = [
     TransformModifier(),
-    GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])
+    GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
 ]
 
 # Apply algorithms.
diff --git a/src/llmcompressor/modifiers/transform/__init__.py b/src/llmcompressor/modifiers/transform/__init__.py
index 85e8972b4..6c65678af 100644
--- a/src/llmcompressor/modifiers/transform/__init__.py
+++ b/src/llmcompressor/modifiers/transform/__init__.py
@@ -1,3 +1,3 @@
 # flake8: noqa
 
-from .transform import TransformModifier
\ No newline at end of file
+from .transform import TransformModifier
diff --git a/src/llmcompressor/modifiers/transform/template/quip.py b/src/llmcompressor/modifiers/transform/template/quip.py
index 070fec03a..e39c32e6d 100644
--- a/src/llmcompressor/modifiers/transform/template/quip.py
+++ b/src/llmcompressor/modifiers/transform/template/quip.py
@@ -1,10 +1,9 @@
-from compressed_tensors.transform import TransformArgs, TransformScheme, TransformConfig
-
+from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme
 
 QUIP = TransformConfig(
     config_groups={
         "v": TransformScheme(
-            type="hadamard",
+            type="random-hadamard",
             apply=[
                 TransformArgs(
                     targets=["Linear"],
@@ -21,7 +20,7 @@
             randomize=True,
         ),
         "u": TransformScheme(
-            type="hadamard",
+            type="random-hadamard",
             apply=[
                 TransformArgs(
                     targets=["Linear"],
@@ -32,10 +31,10 @@
                     targets=["Linear"],
                     location="output", # non-mergable
                     inverse=True,
-                    ignore="lm_head"
+                    ignore="lm_head",
                 ),
             ],
             randomize=True,
         ),
     }
-)
\ No newline at end of file
+)
diff --git a/src/llmcompressor/modifiers/transform/template/spinquant.py b/src/llmcompressor/modifiers/transform/template/spinquant.py
index b9d7c5844..d628cbfd9 100644
--- a/src/llmcompressor/modifiers/transform/template/spinquant.py
+++ b/src/llmcompressor/modifiers/transform/template/spinquant.py
@@ -1,5 +1,4 @@
-from compressed_tensors.transform import TransformArgs, TransformScheme, TransformConfig
-
+from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme
 
 LLAMA_SPINQUANT = TransformConfig(
     transform_groups={
         "R1": TransformScheme(
             type="hadamard",
             apply=[
@@ -62,4 +61,4 @@
], ), } -) \ No newline at end of file +) diff --git a/src/llmcompressor/modifiers/transform/transform.py b/src/llmcompressor/modifiers/transform/transform.py index 1700e12f1..6cd1417b5 100644 --- a/src/llmcompressor/modifiers/transform/transform.py +++ b/src/llmcompressor/modifiers/transform/transform.py @@ -1,12 +1,13 @@ from typing import Dict, Optional +from compressed_tensors.transform import TransformScheme, apply_transform_config + from llmcompressor.core import State from llmcompressor.modifiers import Modifier -from compressed_tensors.transform import TransformConfig, TransformScheme, TransformFactory - from .template.quip import QUIP + class TransformModifier(Modifier): preset_config: Optional[str] = None config_groups: Optional[Dict[str, TransformScheme]] = None @@ -18,11 +19,8 @@ def on_initialize(self, state: State, **kwargs): # import config template and customize to model pass - - #config = TransformConfig(config_groups=self.config_groups) + # config = TransformConfig(config_groups=self.config_groups) config = QUIP - # TODO: use CT-provided apply_transform_config - for name, scheme in config.config_groups.items(): - factory = TransformFactory.from_scheme(scheme, name=name) - factory.apply_to_model(state.model) \ No newline at end of file + apply_transform_config(state.model, config) + breakpoint() From 3aa35e727143ee35cc1226fe86d863de8eff85df Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 11 Jun 2025 23:22:24 -0400 Subject: [PATCH 03/54] add correctness test, note that precision makes a large difference Signed-off-by: Kyle Sayers --- .../modifiers/transform/test_correctness.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 tests/llmcompressor/modifiers/transform/test_correctness.py diff --git a/tests/llmcompressor/modifiers/transform/test_correctness.py b/tests/llmcompressor/modifiers/transform/test_correctness.py new file mode 100644 index 000000000..8fca9639b --- /dev/null +++ b/tests/llmcompressor/modifiers/transform/test_correctness.py @@ -0,0 +1,29 @@ +import pytest +import torch +from compressed_tensors.transform import apply_transform_config +from transformers import AutoModelForCausalLM + +from llmcompressor.modifiers.transform.template.quip import QUIP + + +@pytest.mark.parametrize( + "dtype,exp_max,exp_mse", [ + (torch.bfloat16, 1.1, 0.012), # constructing and running transforms in float32 can improve to (~0.6562, ~0.0055) # noqa: E501 + (torch.float32, 4e-4, 2e-9) + ] +) +def test_apply_correctness(dtype, exp_max, exp_mse): + model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Meta-Llama-3-8B-Instruct", device_map="cuda", torch_dtype=dtype + ) + + input = {k: v.to("cuda") for k, v in model.dummy_inputs.items()} + with torch.no_grad(): + true_output = model(**input) + + apply_transform_config(model, QUIP) + with torch.no_grad(): + output = model(**input) + + assert torch.max(true_output.logits - output.logits) <= exp_max + assert torch.nn.MSELoss()(output.logits, true_output.logits) <= exp_mse From b6c088e787454b419962544aa2ce9f852b73692a Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 23 Jun 2025 20:18:52 +0000 Subject: [PATCH 04/54] add on lifecycle methods Signed-off-by: Brian Dellabetta --- examples/transform/llama3_example.py | 4 +-- .../modifiers/transform/transform.py | 33 ++++++++++++++++--- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py index 41bb4921c..b868d4b2a 100644 --- 
a/examples/transform/llama3_example.py +++ b/examples/transform/llama3_example.py @@ -3,14 +3,13 @@ from llmcompressor.modifiers.quantization import GPTQModifier from llmcompressor.modifiers.transform import TransformModifier -from llmcompressor.transformers import oneshot +from llmcompressor import oneshot # Select model and load it. MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" model = AutoModelForCausalLM.from_pretrained( MODEL_ID, - device_map="auto", torch_dtype="auto", ) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) @@ -66,6 +65,7 @@ def tokenize(sample): model=model, dataset=ds, recipe=recipe, + pipeline="sequential", max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, ) diff --git a/src/llmcompressor/modifiers/transform/transform.py b/src/llmcompressor/modifiers/transform/transform.py index 6cd1417b5..6b8e89927 100644 --- a/src/llmcompressor/modifiers/transform/transform.py +++ b/src/llmcompressor/modifiers/transform/transform.py @@ -2,7 +2,7 @@ from compressed_tensors.transform import TransformScheme, apply_transform_config -from llmcompressor.core import State +from llmcompressor.core import Event, EventType, State from llmcompressor.modifiers import Modifier from .template.quip import QUIP @@ -12,9 +12,9 @@ class TransformModifier(Modifier): preset_config: Optional[str] = None config_groups: Optional[Dict[str, TransformScheme]] = None - # model validator to validate both preset and config gropus are not provided + # model validator to validate both preset and config groups are not provided - def on_initialize(self, state: State, **kwargs): + def on_initialize(self, state: State, **kwargs) -> bool: if self.preset_config is not None: # import config template and customize to model pass @@ -23,4 +23,29 @@ def on_initialize(self, state: State, **kwargs): config = QUIP apply_transform_config(state.model, config) - breakpoint() + + return True + + def on_start(self, state: State, event: Event, **kwargs): + self.started_ = True + + def on_event(self, state: State, event: Event, **kwargs): + if event.type_ == EventType.CALIBRATION_EPOCH_START: + if not self.started_: + self.on_start(state, None) + + elif event.type_ == EventType.SEQUENTIAL_EPOCH_END: + pass + + elif event.type_ == EventType.CALIBRATION_EPOCH_END: + if not self.ended_: + self.on_end(state, None) + + def on_end(self, state: State, event: Event, **kwargs): + self.ended_ = True + + def on_finalize(self, state: State, **kwargs) -> bool: + if not self.ended_: + self.on_end(state, None) + + return True From 320712434f3bdbae7330f3d6dc2a4f0f0224a497 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 2 Jul 2025 15:07:42 +0000 Subject: [PATCH 05/54] TransformModifier with SpinQuant R1&R2 Signed-off-by: Brian Dellabetta --- examples/transform/llama3_example.py | 26 +++++------ .../modifiers/transform/__init__.py | 1 + .../modifiers/transform/presets/__init__.py | 8 ++++ .../transform/{template => presets}/quip.py | 0 .../{template => presets}/spinquant.py | 43 +++++++++++++++++++ .../modifiers/transform/transform.py | 31 +++++++------ 6 files changed, 84 insertions(+), 25 deletions(-) create mode 100644 src/llmcompressor/modifiers/transform/presets/__init__.py rename src/llmcompressor/modifiers/transform/{template => presets}/quip.py (100%) rename src/llmcompressor/modifiers/transform/{template => presets}/spinquant.py (61%) diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py index b868d4b2a..90051c9a8 100644 --- 
a/examples/transform/llama3_example.py +++ b/examples/transform/llama3_example.py @@ -1,9 +1,10 @@ from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer -from llmcompressor.modifiers.quantization import GPTQModifier -from llmcompressor.modifiers.transform import TransformModifier from llmcompressor import oneshot +from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier +from llmcompressor.modifiers.transform import TransformModifier +from llmcompressor.utils import dispatch_for_generation # Select model and load it. MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" @@ -56,8 +57,8 @@ def tokenize(sample): # Configure the quantization algorithm to run. # * quantize the weights to 4 bit with GPTQ with a group size 128 recipe = [ - TransformModifier(), - GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), + TransformModifier(preset_config="LLAMA_SPINQUANT_R1R2"), + QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), ] # Apply algorithms. @@ -70,15 +71,16 @@ def tokenize(sample): num_calibration_samples=NUM_CALIBRATION_SAMPLES, ) -# Confirm generations of the quantized model look sane. -print("\n\n") -print("========== SAMPLE GENERATION ==============") -input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") -output = model.generate(input_ids, max_new_tokens=100) -print(tokenizer.decode(output[0])) -print("==========================================\n\n") +# # Confirm generations of the quantized model look sane. +# print("\n\n") +# print("========== SAMPLE GENERATION ==============") +# dispatch_for_generation(model) +# input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") +# output = model.generate(input_ids, max_new_tokens=100) +# print(tokenizer.decode(output[0])) +# print("==========================================\n\n") # Save to disk compressed. 
-SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128" +SAVE_DIR = MODEL_ID.split("/")[1] + "-transform-quant-w4a16" model.save_pretrained(SAVE_DIR, save_compressed=True) tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modifiers/transform/__init__.py b/src/llmcompressor/modifiers/transform/__init__.py index 6c65678af..036d35b60 100644 --- a/src/llmcompressor/modifiers/transform/__init__.py +++ b/src/llmcompressor/modifiers/transform/__init__.py @@ -1,3 +1,4 @@ # flake8: noqa from .transform import TransformModifier +from .transform.presets import TRANSFORM_PRESETS diff --git a/src/llmcompressor/modifiers/transform/presets/__init__.py b/src/llmcompressor/modifiers/transform/presets/__init__.py new file mode 100644 index 000000000..a36bbc4d1 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/presets/__init__.py @@ -0,0 +1,8 @@ +from .quip import QUIP +from .spinquant import LLAMA_SPINQUANT, LLAMA_SPINQUANT_R1R2 + +TRANSFORM_PRESETS = { + "QUIP": QUIP, + "LLAMA_SPINQUANT": LLAMA_SPINQUANT, + "LLAMA_SPINQUANT_R1R2": LLAMA_SPINQUANT_R1R2, +} diff --git a/src/llmcompressor/modifiers/transform/template/quip.py b/src/llmcompressor/modifiers/transform/presets/quip.py similarity index 100% rename from src/llmcompressor/modifiers/transform/template/quip.py rename to src/llmcompressor/modifiers/transform/presets/quip.py diff --git a/src/llmcompressor/modifiers/transform/template/spinquant.py b/src/llmcompressor/modifiers/transform/presets/spinquant.py similarity index 61% rename from src/llmcompressor/modifiers/transform/template/spinquant.py rename to src/llmcompressor/modifiers/transform/presets/spinquant.py index d628cbfd9..194818b38 100644 --- a/src/llmcompressor/modifiers/transform/template/spinquant.py +++ b/src/llmcompressor/modifiers/transform/presets/spinquant.py @@ -1,5 +1,8 @@ from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme +# Ref: https://arxiv.org/pdf/2405.16406 Fig 1 + +# All rotations LLAMA_SPINQUANT = TransformConfig( transform_groups={ "R1": TransformScheme( @@ -62,3 +65,43 @@ ), } ) + + +# Mergeable rotations R1 and R2 only +LLAMA_SPINQUANT_R1R2 = TransformConfig( + config_groups={ + "R1": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["embed_tokens", "o_proj", "down_proj"], + location="weight_output", + ), + TransformArgs( + targets=[ + "q_proj", + "k_proj", + "v_proj", + "up_proj", + "gate_proj", + "lm_head", + ], + location="weight_input", + inverse=True, + ), + ], + ), + "R2": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["v_proj"], + location="weight_output", + ), + TransformArgs( + targets=["o_proj"], location="weight_input", inverse=True + ), + ], + ), + } +) diff --git a/src/llmcompressor/modifiers/transform/transform.py b/src/llmcompressor/modifiers/transform/transform.py index 6b8e89927..d7ac10aaa 100644 --- a/src/llmcompressor/modifiers/transform/transform.py +++ b/src/llmcompressor/modifiers/transform/transform.py @@ -1,28 +1,33 @@ -from typing import Dict, Optional +from typing import Optional -from compressed_tensors.transform import TransformScheme, apply_transform_config +from compressed_tensors.transform import TransformConfig, apply_transform_config +from pydantic import ValidationError, model_validator from llmcompressor.core import Event, EventType, State from llmcompressor.modifiers import Modifier - -from .template.quip import QUIP +from llmcompressor.modifiers.transform.presets import TRANSFORM_PRESETS class TransformModifier(Modifier): 
preset_config: Optional[str] = None - config_groups: Optional[Dict[str, TransformScheme]] = None + config: Optional[TransformConfig] = None # model validator to validate both preset and config groups are not provided + @model_validator(mode="after") + def validate_model_after(model: "TransformModifier") -> "TransformModifier": + if model.preset_config is None and model.config is None: + raise ValidationError("Either a config or a preset_config must be provided") + + if model.preset_config is not None: + if model.preset_config not in TRANSFORM_PRESETS: + raise ValidationError( + f"Invalid preset_config '{model.preset_config}' " + f"must be in {TRANSFORM_PRESETS.keys()}" + ) + model.config = TRANSFORM_PRESETS[model.preset_config] def on_initialize(self, state: State, **kwargs) -> bool: - if self.preset_config is not None: - # import config template and customize to model - pass - - # config = TransformConfig(config_groups=self.config_groups) - config = QUIP - - apply_transform_config(state.model, config) + apply_transform_config(state.model, self.config) return True From a88ca3c0ef4866c4239a7f34ca62ed90f9554586 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 2 Jul 2025 18:59:36 +0000 Subject: [PATCH 06/54] spinquant and quip_online, running but outputting gibberish Signed-off-by: Brian Dellabetta --- .../modifiers/transform/__init__.py | 2 +- .../modifiers/transform/presets/__init__.py | 3 +- .../modifiers/transform/presets/quip.py | 58 ++++++++++++++ .../modifiers/transform/presets/spinquant.py | 78 ++++++------------- .../modifiers/transform/transform.py | 2 + .../modifiers/transform/test_correctness.py | 13 +++- 6 files changed, 95 insertions(+), 61 deletions(-) diff --git a/src/llmcompressor/modifiers/transform/__init__.py b/src/llmcompressor/modifiers/transform/__init__.py index 036d35b60..c43958136 100644 --- a/src/llmcompressor/modifiers/transform/__init__.py +++ b/src/llmcompressor/modifiers/transform/__init__.py @@ -1,4 +1,4 @@ # flake8: noqa +from .presets import TRANSFORM_PRESETS from .transform import TransformModifier -from .transform.presets import TRANSFORM_PRESETS diff --git a/src/llmcompressor/modifiers/transform/presets/__init__.py b/src/llmcompressor/modifiers/transform/presets/__init__.py index a36bbc4d1..0d4a06b90 100644 --- a/src/llmcompressor/modifiers/transform/presets/__init__.py +++ b/src/llmcompressor/modifiers/transform/presets/__init__.py @@ -1,8 +1,9 @@ -from .quip import QUIP +from .quip import QUIP, QUIP_ONLINE from .spinquant import LLAMA_SPINQUANT, LLAMA_SPINQUANT_R1R2 TRANSFORM_PRESETS = { "QUIP": QUIP, + "QUIP_ONLINE": QUIP_ONLINE, "LLAMA_SPINQUANT": LLAMA_SPINQUANT, "LLAMA_SPINQUANT_R1R2": LLAMA_SPINQUANT_R1R2, } diff --git a/src/llmcompressor/modifiers/transform/presets/quip.py b/src/llmcompressor/modifiers/transform/presets/quip.py index e39c32e6d..4ce5e47ae 100644 --- a/src/llmcompressor/modifiers/transform/presets/quip.py +++ b/src/llmcompressor/modifiers/transform/presets/quip.py @@ -38,3 +38,61 @@ ), } ) + +# https://github.com/vllm-project/llm-compressor/blob/b43b27a2f277a5e62be4f8c713b84fd1c7aa116b/weight_transform.py#L24-L105 +QUIP_ONLINE = TransformConfig( + config_groups={ + "u_transform_q_o_down_proj": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=[ + "re:.*.attn.q_proj$", + "re:.*.attn.o_proj$", + "re:.*.mlp.down_proj$", + ], + location="weight_input", + ) + ], + ), + "u_transform_k_v_proj": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["re:.*.attn.k_proj$", 
"re:.*.attn.v_proj$"], + location="weight_input", + ) + ], + ), + "u_transform_gate_up_proj": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["re:.*.mlp.gate_proj$", "re:.*.mlp.up_proj$"], + location="weight_input", + ) + ], + ), + "v_transform_linear": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["Linear"], + location="weight_output", + ignore=["re:.*.mlp.down_proj$", "lm_head"], + inverse=True, + ) + ], + ), + "v_transform_down_proj": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=["re:.*.mlp.down_proj$"], + location="weight_output", + inverse=True, + ) + ], + ), + } +) diff --git a/src/llmcompressor/modifiers/transform/presets/spinquant.py b/src/llmcompressor/modifiers/transform/presets/spinquant.py index 194818b38..555b03fd6 100644 --- a/src/llmcompressor/modifiers/transform/presets/spinquant.py +++ b/src/llmcompressor/modifiers/transform/presets/spinquant.py @@ -2,23 +2,23 @@ # Ref: https://arxiv.org/pdf/2405.16406 Fig 1 -# All rotations -LLAMA_SPINQUANT = TransformConfig( - transform_groups={ +# Mergeable rotations R1 and R2 only +LLAMA_SPINQUANT_R1R2 = TransformConfig( + config_groups={ "R1": TransformScheme( type="hadamard", apply=[ TransformArgs( - targets=["embed_tokens", "o_proj", "down_proj"], + targets=["re:.*embed_tokens$", "re:.*o_proj$", "re:.*down_proj$"], location="weight_output", ), TransformArgs( targets=[ - "q_proj", - "k_proj", - "v_proj", - "up_proj", - "gate_proj", + "re:.*q_proj$", + "re:.*k_proj$", + "re:.*v_proj$", + "re:.*up_proj$", + "re:.*gate_proj$", "lm_head", ], location="weight_input", @@ -30,23 +30,31 @@ type="hadamard", apply=[ TransformArgs( - targets=["v_proj"], + targets=["re:.*v_proj$"], location="weight_output", ), TransformArgs( - targets=["o_proj"], location="weight_input", inverse=True + targets=["re:.*o_proj$"], location="weight_input", inverse=True ), ], ), + } +) + +# All rotations +LLAMA_SPINQUANT = TransformConfig( + config_groups={ + "R1": LLAMA_SPINQUANT_R1R2.config_groups["R1"], + "R2": LLAMA_SPINQUANT_R1R2.config_groups["R2"], "R3": TransformScheme( type="hadamard", apply=[ TransformArgs( - targets=["self_attn"], + targets=["re:.*self_attn$"], location="k_cache", ), TransformArgs( - targets=["self_attn"], + targets=["re:.*self_attn$"], location="q_attn", ), ], @@ -55,51 +63,11 @@ type="hadamard", apply=[ TransformArgs( - targets=["down_proj"], + targets=["re:.*down_proj$"], location="input", ), TransformArgs( - targets=["down_proj"], location="weight_input", inverse=True - ), - ], - ), - } -) - - -# Mergeable rotations R1 and R2 only -LLAMA_SPINQUANT_R1R2 = TransformConfig( - config_groups={ - "R1": TransformScheme( - type="hadamard", - apply=[ - TransformArgs( - targets=["embed_tokens", "o_proj", "down_proj"], - location="weight_output", - ), - TransformArgs( - targets=[ - "q_proj", - "k_proj", - "v_proj", - "up_proj", - "gate_proj", - "lm_head", - ], - location="weight_input", - inverse=True, - ), - ], - ), - "R2": TransformScheme( - type="hadamard", - apply=[ - TransformArgs( - targets=["v_proj"], - location="weight_output", - ), - TransformArgs( - targets=["o_proj"], location="weight_input", inverse=True + targets=["re:.*down_proj$"], location="weight_input", inverse=True ), ], ), diff --git a/src/llmcompressor/modifiers/transform/transform.py b/src/llmcompressor/modifiers/transform/transform.py index d7ac10aaa..e94a3dc35 100644 --- a/src/llmcompressor/modifiers/transform/transform.py +++ b/src/llmcompressor/modifiers/transform/transform.py @@ -26,6 
+26,8 @@ def validate_model_after(model: "TransformModifier") -> "TransformModifier": ) model.config = TRANSFORM_PRESETS[model.preset_config] + return model + def on_initialize(self, state: State, **kwargs) -> bool: apply_transform_config(state.model, self.config) diff --git a/tests/llmcompressor/modifiers/transform/test_correctness.py b/tests/llmcompressor/modifiers/transform/test_correctness.py index 8fca9639b..660bab0ef 100644 --- a/tests/llmcompressor/modifiers/transform/test_correctness.py +++ b/tests/llmcompressor/modifiers/transform/test_correctness.py @@ -7,10 +7,15 @@ @pytest.mark.parametrize( - "dtype,exp_max,exp_mse", [ - (torch.bfloat16, 1.1, 0.012), # constructing and running transforms in float32 can improve to (~0.6562, ~0.0055) # noqa: E501 - (torch.float32, 4e-4, 2e-9) - ] + "dtype,exp_max,exp_mse", + [ + ( + torch.bfloat16, + 1.1, + 0.012, + ), # constructing and running transforms in float32 can improve to (~0.6562, ~0.0055) # noqa: E501 + (torch.float32, 4e-4, 2e-9), + ], ) def test_apply_correctness(dtype, exp_max, exp_mse): model = AutoModelForCausalLM.from_pretrained( From 5bd51df3668e7be7b2ea969bf85e1ae7f528d8ee Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 2 Jul 2025 19:11:20 +0000 Subject: [PATCH 07/54] updated example Signed-off-by: Brian Dellabetta --- examples/transform/llama3_example.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py index 90051c9a8..62801935e 100644 --- a/examples/transform/llama3_example.py +++ b/examples/transform/llama3_example.py @@ -7,7 +7,7 @@ from llmcompressor.utils import dispatch_for_generation # Select model and load it. -MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" +MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" # "meta-llama/Meta-Llama-3-8B-Instruct" model = AutoModelForCausalLM.from_pretrained( MODEL_ID, @@ -57,6 +57,10 @@ def tokenize(sample): # Configure the quantization algorithm to run. # * quantize the weights to 4 bit with GPTQ with a group size 128 recipe = [ + # TODO preset_config="LLAMA_SPINQUANT_R1R2" outputs gibberish + # TODO preset_config="QUIP_ONLINE" outputs gibberish + # preset_config="QUIP" output sensible, but cannot load saved + # checkpoint or run evals (~4hrs to run) TransformModifier(preset_config="LLAMA_SPINQUANT_R1R2"), QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), ] @@ -72,12 +76,12 @@ def tokenize(sample): ) # # Confirm generations of the quantized model look sane. -# print("\n\n") -# print("========== SAMPLE GENERATION ==============") -# dispatch_for_generation(model) -# input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") -# output = model.generate(input_ids, max_new_tokens=100) -# print(tokenizer.decode(output[0])) +print("\n\n") +print("========== SAMPLE GENERATION ==============") +dispatch_for_generation(model) +input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=100) +print(tokenizer.decode(output[0])) # print("==========================================\n\n") # Save to disk compressed. 
From 3c216dd685fdac9172213b1200be5a5ee91be532 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 8 Jul 2025 21:29:27 +0000 Subject: [PATCH 08/54] DummyModel script Signed-off-by: Brian Dellabetta --- examples/transform/llama3_example.py | 23 ++-- examples/transform/spinquant_dummy.py | 112 ++++++++++++++++++ src/llmcompressor/entrypoints/oneshot.py | 3 +- .../modifiers/transform/presets/spinquant.py | 43 ++++--- 4 files changed, 154 insertions(+), 27 deletions(-) create mode 100644 examples/transform/spinquant_dummy.py diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py index 62801935e..1ec7b6516 100644 --- a/examples/transform/llama3_example.py +++ b/examples/transform/llama3_example.py @@ -7,7 +7,9 @@ from llmcompressor.utils import dispatch_for_generation # Select model and load it. -MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" # "meta-llama/Meta-Llama-3-8B-Instruct" +# MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" +# MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct" # TODO hidden size 3072 causes failure when creating hadamard +MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" model = AutoModelForCausalLM.from_pretrained( MODEL_ID, @@ -62,17 +64,18 @@ def tokenize(sample): # preset_config="QUIP" output sensible, but cannot load saved # checkpoint or run evals (~4hrs to run) TransformModifier(preset_config="LLAMA_SPINQUANT_R1R2"), - QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), + # QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), ] # Apply algorithms. oneshot( model=model, - dataset=ds, recipe=recipe, - pipeline="sequential", - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, + # dataset=ds, + pipeline="datafree", + # max_seq_length=MAX_SEQUENCE_LENGTH, + # num_calibration_samples=NUM_CALIBRATION_SAMPLES, + log_dir=None, ) # # Confirm generations of the quantized model look sane. @@ -84,7 +87,7 @@ def tokenize(sample): print(tokenizer.decode(output[0])) # print("==========================================\n\n") -# Save to disk compressed. -SAVE_DIR = MODEL_ID.split("/")[1] + "-transform-quant-w4a16" -model.save_pretrained(SAVE_DIR, save_compressed=True) -tokenizer.save_pretrained(SAVE_DIR) +# # Save to disk compressed. 
+# SAVE_DIR = MODEL_ID.split("/")[1] + "-transform-quant-w4a16" +# model.save_pretrained(SAVE_DIR, save_compressed=True) +# tokenizer.save_pretrained(SAVE_DIR) diff --git a/examples/transform/spinquant_dummy.py b/examples/transform/spinquant_dummy.py new file mode 100644 index 000000000..494b6c611 --- /dev/null +++ b/examples/transform/spinquant_dummy.py @@ -0,0 +1,112 @@ +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +import torch +from compressed_tensors.utils import update_parameter_data +from llmcompressor import oneshot +from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier +from llmcompressor.modifiers.transform import TransformModifier +from llmcompressor.utils import dispatch_for_generation +from transformers.models.llama.modeling_llama import ( + LlamaRMSNorm, +) + +hidden_dim = intermediate_dim = 64 +up_dim = 128 +num_embeddings = 12 + + +# TODO remove file before merging + + +class DummySelfAttn(torch.nn.Module): + def __init__(self, hidden_dim, intermediate_dim): + super().__init__() + self.q_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=None) + self.k_proj = torch.nn.Linear(hidden_dim, intermediate_dim, bias=None) + self.v_proj = torch.nn.Linear(hidden_dim, intermediate_dim, bias=None) + self.o_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=None) + self.num_heads = 1 + self.num_key_value_groups = 1 + + def forward(self, hidden_states): + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + ### EAGER ATTENTION + attn_weights = torch.matmul(q.T, k) + + attn_weights = torch.nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(q.dtype) + attn_output = torch.matmul(attn_weights, v.T) + attn_output = attn_output.T.contiguous() + + return self.o_proj(attn_output) + + +class DummyMLP(torch.nn.Module): + def __init__(self, hidden_dim, up_dim): + super().__init__() + self.up_proj = torch.nn.Linear(hidden_dim, up_dim, bias=None) + self.gate_proj = torch.nn.Linear(hidden_dim, up_dim, bias=None) + self.down_proj = torch.nn.Linear(up_dim, hidden_dim, bias=None) + self.act_fn = torch.nn.SiLU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class DummyModel(torch.nn.Module): + def __init__(self, num_embeddings, hidden_dim, intermediate_dim, up_dim): + super().__init__() + self.embed_tokens = torch.nn.Embedding(num_embeddings, hidden_dim) + self.input_layernorm = LlamaRMSNorm(hidden_dim) + self.post_attention_layernorm = LlamaRMSNorm(hidden_dim) + self.self_attn = DummySelfAttn(hidden_dim, intermediate_dim) + self.mlp = DummyMLP(hidden_dim, up_dim) + self.lm_head = torch.nn.Linear(hidden_dim, num_embeddings, bias=None) + + def forward(self, input_ids): + x = self.embed_tokens(input_ids) + x = self.input_layernorm(x) + x = self.self_attn(x) + x = self.post_attention_layernorm(x) + x = self.mlp(x) + return self.lm_head(x) + + +model = DummyModel(num_embeddings, hidden_dim, intermediate_dim, up_dim) + +# TODO Uncomment this to see norm diff > 1e-6 +# This is due to issue Kyle spotted in https://arxiv.org/pdf/2405.16406 Page 5 Footnote 2 +# Will have to fuse layernorms with subsequent layers so that input_layernorm.weight is equal to torch.ones() (this apparently makes it rotation invariant) +# https://github.com/facebookresearch/SpinQuant/blob/8f47aa3f00e8662caf1a484153920a07e5281c3a/utils/fuse_norm_utils.py#L39 +# update_parameter_data( +# model.input_layernorm, +# 
torch.rand(model.input_layernorm.weight.shape), +# "weight", +# ) + +input_ids = torch.IntTensor([1, 2, 3, 4, 5]) +orig_output = model(input_ids) + +recipe = [ + # NOTE: preset_config="QUIP" output sensible, but cannot load saved + # checkpoint or run evals (~4hrs to run) + TransformModifier(preset_config="LLAMA_SPINQUANT_R1R2"), + # QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), +] + +oneshot( + model=model, + recipe=recipe, + pipeline="datafree", + log_dir=None, +) + +# # Confirm generations of the quantized model look the same +transformed_output = model(input_ids) + +print(f"Norm Diff {(orig_output-transformed_output).norm()}") +print(f"Norm {orig_output.norm()}, {transformed_output.norm()}") diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index 55de99501..df815aa4f 100644 --- a/src/llmcompressor/entrypoints/oneshot.py +++ b/src/llmcompressor/entrypoints/oneshot.py @@ -125,7 +125,8 @@ def __init__( self.output_dir = output_dir # initialize the model and processor - pre_process(model_args) + # TODO Remove Comment before merge, this is just needed for DummyModel + # pre_process(model_args) # Set instance attributes self.model = self.model_args.model diff --git a/src/llmcompressor/modifiers/transform/presets/spinquant.py b/src/llmcompressor/modifiers/transform/presets/spinquant.py index 555b03fd6..d9765a6d5 100644 --- a/src/llmcompressor/modifiers/transform/presets/spinquant.py +++ b/src/llmcompressor/modifiers/transform/presets/spinquant.py @@ -9,43 +9,54 @@ type="hadamard", apply=[ TransformArgs( - targets=["re:.*embed_tokens$", "re:.*o_proj$", "re:.*down_proj$"], + targets=[ + # outermost rotation + "re:.*embed_tokens$", + # attention rotations + "re:.*o_proj$", + # mlp rotations + "re:.*down_proj$", + ], location="weight_output", ), TransformArgs( targets=[ + # outermost rotation + "lm_head", + # attention rotations "re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$", + # mlp rotations "re:.*up_proj$", "re:.*gate_proj$", - "lm_head", ], location="weight_input", inverse=True, ), ], ), - "R2": TransformScheme( - type="hadamard", - apply=[ - TransformArgs( - targets=["re:.*v_proj$"], - location="weight_output", - ), - TransformArgs( - targets=["re:.*o_proj$"], location="weight_input", inverse=True - ), - ], - ), + # "R2": TransformScheme( + # type="hadamard", + # # TODO infer head_dim from config.json in SpinQuantModifier + # head_dim=128, + # apply=[ + # TransformArgs(targets=["re:.*v_proj$"], location="weight_output"), + # TransformArgs( + # targets=["re:.*o_proj$"], + # location="weight_input", + # inverse=True, + # ), + # ], + # ), } ) # All rotations LLAMA_SPINQUANT = TransformConfig( config_groups={ - "R1": LLAMA_SPINQUANT_R1R2.config_groups["R1"], - "R2": LLAMA_SPINQUANT_R1R2.config_groups["R2"], + # "R1": LLAMA_SPINQUANT_R1R2.config_groups["R1"], + # "R2": LLAMA_SPINQUANT_R1R2.config_groups["R2"], "R3": TransformScheme( type="hadamard", apply=[ From bbcdc8ca6cd0c055e9baa543fe91fd0c10b88a11 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 9 Jul 2025 23:12:31 -0400 Subject: [PATCH 09/54] implement fuse_norm_linears Signed-off-by: Kyle Sayers --- src/llmcompressor/modeling/fuse.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 src/llmcompressor/modeling/fuse.py diff --git a/src/llmcompressor/modeling/fuse.py b/src/llmcompressor/modeling/fuse.py new file mode 100644 index 000000000..4a9a34bb3 --- /dev/null +++ b/src/llmcompressor/modeling/fuse.py @@ -0,0 +1,28 @@ +from 
typing import Iterable + +import torch +from compressed_tensors import update_offload_parameter + +__all__ = ["fuse_norm_linears"] + + +def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]): + """ + Fuse a norm layer into subsequent linear layers. This useful for ensuring transform + invariance between norm and linear layers. + + Note that a model cannot be properly trained after its norms have been fused + + :param norm: norm layer whose weight will be fused into subsequent linears + :param linears: linear layers which directly follow the norm layer + """ + if isinstance(norm, torch.nn.RMSNorm): + for linear in linears: + # spinquant does this op in float64 + new_weight = linear.weight * norm.weight + update_offload_parameter(linear, "weight", new_weight) + + update_offload_parameter(norm, "weight", torch.ones_like(norm.weight)) + + else: + raise ValueError(f"Cannot fuse norm of type {type(norm)}") From f5c2150eefb3e87b1719ecc75b03de9a370bb94c Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 11 Jul 2025 11:07:36 -0400 Subject: [PATCH 10/54] R1 working Signed-off-by: Kyle Sayers --- src/llmcompressor/modeling/__init__.py | 1 + src/llmcompressor/modeling/fuse.py | 13 +++++++++---- src/llmcompressor/modifiers/transform/transform.py | 11 +++++++++-- src/llmcompressor/pipelines/data_free/pipeline.py | 5 +++++ 4 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/llmcompressor/modeling/__init__.py b/src/llmcompressor/modeling/__init__.py index e2c22ed1f..871955916 100644 --- a/src/llmcompressor/modeling/__init__.py +++ b/src/llmcompressor/modeling/__init__.py @@ -1,3 +1,4 @@ # flake8: noqa from .prepare import * +from .fuse import * \ No newline at end of file diff --git a/src/llmcompressor/modeling/fuse.py b/src/llmcompressor/modeling/fuse.py index 4a9a34bb3..a87914a8b 100644 --- a/src/llmcompressor/modeling/fuse.py +++ b/src/llmcompressor/modeling/fuse.py @@ -1,7 +1,9 @@ from typing import Iterable import torch -from compressed_tensors import update_offload_parameter +from compressed_tensors import get_execution_device, align_module_device, update_offload_parameter + +from transformers.models.llama.modeling_llama import LlamaRMSNorm __all__ = ["fuse_norm_linears"] @@ -16,10 +18,13 @@ def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]) :param norm: norm layer whose weight will be fused into subsequent linears :param linears: linear layers which directly follow the norm layer """ - if isinstance(norm, torch.nn.RMSNorm): + if isinstance(norm, (torch.nn.RMSNorm, LlamaRMSNorm)): for linear in linears: - # spinquant does this op in float64 - new_weight = linear.weight * norm.weight + # NOTE: spinquant does this op in float64 + exec_device = get_execution_device(norm) + with align_module_device(norm, exec_device), align_module_device(linear, exec_device): + new_weight = linear.weight * norm.weight + update_offload_parameter(linear, "weight", new_weight) update_offload_parameter(norm, "weight", torch.ones_like(norm.weight)) diff --git a/src/llmcompressor/modifiers/transform/transform.py b/src/llmcompressor/modifiers/transform/transform.py index e94a3dc35..3c59cde03 100644 --- a/src/llmcompressor/modifiers/transform/transform.py +++ b/src/llmcompressor/modifiers/transform/transform.py @@ -4,6 +4,7 @@ from pydantic import ValidationError, model_validator from llmcompressor.core import Event, EventType, State +from llmcompressor.modeling import fuse_norm_linears from llmcompressor.modifiers import Modifier from 
llmcompressor.modifiers.transform.presets import TRANSFORM_PRESETS @@ -29,13 +30,19 @@ def validate_model_after(model: "TransformModifier") -> "TransformModifier": return model def on_initialize(self, state: State, **kwargs) -> bool: - apply_transform_config(state.model, self.config) - return True def on_start(self, state: State, event: Event, **kwargs): self.started_ = True + for layer in state.model.model.layers: + fuse_norm_linears(layer.input_layernorm, (layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj)) + fuse_norm_linears(layer.post_attention_layernorm, (layer.mlp.gate_proj, layer.mlp.up_proj)) + + # needs to happen after the model has been hooked to execute on the GPU + # otherwise we're applying weight transforms on CPU + apply_transform_config(state.model, self.config) + def on_event(self, state: State, event: Event, **kwargs): if event.type_ == EventType.CALIBRATION_EPOCH_START: if not self.started_: diff --git a/src/llmcompressor/pipelines/data_free/pipeline.py b/src/llmcompressor/pipelines/data_free/pipeline.py index 587f7ca69..7ad6d56dc 100644 --- a/src/llmcompressor/pipelines/data_free/pipeline.py +++ b/src/llmcompressor/pipelines/data_free/pipeline.py @@ -5,6 +5,7 @@ from llmcompressor.core.session_functions import LifecycleCallbacks from llmcompressor.pipelines.registry import CalibrationPipeline +from llmcompressor.utils.dev import dispatch_for_generation if TYPE_CHECKING: from llmcompressor.args.dataset_arguments import DatasetArguments @@ -27,5 +28,9 @@ def __call__( :param dataloader: loads data for calibration :param dataset_args: dataset arguments relevant to pipelines """ + # some ops are still performed on the model by modifiers + # we want those ops to occur on the GPU + dispatch_for_generation(model) + LifecycleCallbacks.calibration_epoch_start() LifecycleCallbacks.calibration_epoch_end() From dc5c30c54df8a19bc9c928e51c648b533c505d4a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 11 Jul 2025 11:38:49 -0400 Subject: [PATCH 11/54] add r2, increase precision Signed-off-by: Kyle Sayers --- examples/transform/llama3_example.py | 1 - src/llmcompressor/modeling/fuse.py | 7 +- .../modifiers/transform/presets/spinquant.py | 70 +++++++++++++++---- 3 files changed, 63 insertions(+), 15 deletions(-) diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py index 1ec7b6516..96d65b997 100644 --- a/examples/transform/llama3_example.py +++ b/examples/transform/llama3_example.py @@ -59,7 +59,6 @@ def tokenize(sample): # Configure the quantization algorithm to run. 
# * quantize the weights to 4 bit with GPTQ with a group size 128 recipe = [ - # TODO preset_config="LLAMA_SPINQUANT_R1R2" outputs gibberish # TODO preset_config="QUIP_ONLINE" outputs gibberish # preset_config="QUIP" output sensible, but cannot load saved # checkpoint or run evals (~4hrs to run) diff --git a/src/llmcompressor/modeling/fuse.py b/src/llmcompressor/modeling/fuse.py index a87914a8b..3e059f7cb 100644 --- a/src/llmcompressor/modeling/fuse.py +++ b/src/llmcompressor/modeling/fuse.py @@ -23,7 +23,12 @@ def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]) # NOTE: spinquant does this op in float64 exec_device = get_execution_device(norm) with align_module_device(norm, exec_device), align_module_device(linear, exec_device): - new_weight = linear.weight * norm.weight + + weight_dtype = linear.weight.dtype + + new_weight = linear.weight.to(torch.float64) * norm.weight.to(torch.float64) + + new_weight = new_weight.to(weight_dtype) update_offload_parameter(linear, "weight", new_weight) diff --git a/src/llmcompressor/modifiers/transform/presets/spinquant.py b/src/llmcompressor/modifiers/transform/presets/spinquant.py index d9765a6d5..62dfb2477 100644 --- a/src/llmcompressor/modifiers/transform/presets/spinquant.py +++ b/src/llmcompressor/modifiers/transform/presets/spinquant.py @@ -36,25 +36,69 @@ ), ], ), - # "R2": TransformScheme( - # type="hadamard", - # # TODO infer head_dim from config.json in SpinQuantModifier - # head_dim=128, - # apply=[ - # TransformArgs(targets=["re:.*v_proj$"], location="weight_output"), - # TransformArgs( - # targets=["re:.*o_proj$"], - # location="weight_input", - # inverse=True, - # ), - # ], - # ), + "R2": TransformScheme( + type="hadamard", + # TODO infer head_dim from config.json in SpinQuantModifier + head_dim=128, + apply=[ + TransformArgs(targets=["re:.*v_proj$"], location="weight_output"), + TransformArgs( + targets=["re:.*o_proj$"], + location="weight_input", + inverse=True, + ), + ], + ), } ) # All rotations LLAMA_SPINQUANT = TransformConfig( config_groups={ + "R1": TransformScheme( + type="hadamard", + apply=[ + TransformArgs( + targets=[ + # outermost rotation + "re:.*embed_tokens$", + # attention rotations + "re:.*o_proj$", + # mlp rotations + "re:.*down_proj$", + ], + location="weight_output", + ), + TransformArgs( + targets=[ + # outermost rotation + "lm_head", + # attention rotations + "re:.*q_proj$", + "re:.*k_proj$", + "re:.*v_proj$", + # mlp rotations + "re:.*up_proj$", + "re:.*gate_proj$", + ], + location="weight_input", + inverse=True, + ), + ], + ), + "R2": TransformScheme( + type="hadamard", + # TODO infer head_dim from config.json in SpinQuantModifier + head_dim=128, + apply=[ + TransformArgs(targets=["re:.*v_proj$"], location="weight_output"), + TransformArgs( + targets=["re:.*o_proj$"], + location="weight_input", + inverse=True, + ), + ], + ), # "R1": LLAMA_SPINQUANT_R1R2.config_groups["R1"], # "R2": LLAMA_SPINQUANT_R1R2.config_groups["R2"], "R3": TransformScheme( From 7172c2604f0301d05ec2be5cb4b1f58d49331d50 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 11 Jul 2025 14:49:03 -0400 Subject: [PATCH 12/54] spinquant modifier Signed-off-by: Kyle Sayers --- examples/transform/llama3_example.py | 4 +- examples/transform/spinquant_dummy.py | 4 +- .../modifiers/transform/__init__.py | 3 +- .../modifiers/transform/presets/__init__.py | 9 - .../modifiers/transform/quip/base.py | 0 .../{presets/quip.py => quip/template.py} | 0 .../modifiers/transform/spinquant/__init__.py | 1 + 
.../modifiers/transform/spinquant/base.py | 215 ++++++++++++++++++ .../spinquant.py => spinquant/template.py} | 0 .../modifiers/transform/transform.py | 65 ------ 10 files changed, 221 insertions(+), 80 deletions(-) delete mode 100644 src/llmcompressor/modifiers/transform/presets/__init__.py create mode 100644 src/llmcompressor/modifiers/transform/quip/base.py rename src/llmcompressor/modifiers/transform/{presets/quip.py => quip/template.py} (100%) create mode 100644 src/llmcompressor/modifiers/transform/spinquant/__init__.py create mode 100644 src/llmcompressor/modifiers/transform/spinquant/base.py rename src/llmcompressor/modifiers/transform/{presets/spinquant.py => spinquant/template.py} (100%) delete mode 100644 src/llmcompressor/modifiers/transform/transform.py diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py index 96d65b997..8c87cb6a6 100644 --- a/examples/transform/llama3_example.py +++ b/examples/transform/llama3_example.py @@ -3,7 +3,7 @@ from llmcompressor import oneshot from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier -from llmcompressor.modifiers.transform import TransformModifier +from llmcompressor.modifiers.transform import SpinQuantModifier from llmcompressor.utils import dispatch_for_generation # Select model and load it. @@ -62,7 +62,7 @@ def tokenize(sample): # TODO preset_config="QUIP_ONLINE" outputs gibberish # preset_config="QUIP" output sensible, but cannot load saved # checkpoint or run evals (~4hrs to run) - TransformModifier(preset_config="LLAMA_SPINQUANT_R1R2"), + SpinQuantModifier(rotations=["R1", "R2"]), # QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), ] diff --git a/examples/transform/spinquant_dummy.py b/examples/transform/spinquant_dummy.py index 494b6c611..3e8c9d483 100644 --- a/examples/transform/spinquant_dummy.py +++ b/examples/transform/spinquant_dummy.py @@ -4,7 +4,7 @@ from compressed_tensors.utils import update_parameter_data from llmcompressor import oneshot from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier -from llmcompressor.modifiers.transform import TransformModifier +from llmcompressor.modifiers.transform import SpinQuantModifier from llmcompressor.utils import dispatch_for_generation from transformers.models.llama.modeling_llama import ( LlamaRMSNorm, @@ -94,7 +94,7 @@ def forward(self, input_ids): recipe = [ # NOTE: preset_config="QUIP" output sensible, but cannot load saved # checkpoint or run evals (~4hrs to run) - TransformModifier(preset_config="LLAMA_SPINQUANT_R1R2"), + SpinQuantModifier(rotations=["R1", "R2"]), # QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), ] diff --git a/src/llmcompressor/modifiers/transform/__init__.py b/src/llmcompressor/modifiers/transform/__init__.py index c43958136..9956d0340 100644 --- a/src/llmcompressor/modifiers/transform/__init__.py +++ b/src/llmcompressor/modifiers/transform/__init__.py @@ -1,4 +1,3 @@ # flake8: noqa -from .presets import TRANSFORM_PRESETS -from .transform import TransformModifier +from .spinquant import SpinQuantModifier diff --git a/src/llmcompressor/modifiers/transform/presets/__init__.py b/src/llmcompressor/modifiers/transform/presets/__init__.py deleted file mode 100644 index 0d4a06b90..000000000 --- a/src/llmcompressor/modifiers/transform/presets/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .quip import QUIP, QUIP_ONLINE -from .spinquant import LLAMA_SPINQUANT, LLAMA_SPINQUANT_R1R2 - -TRANSFORM_PRESETS = { - "QUIP": 
QUIP, - "QUIP_ONLINE": QUIP_ONLINE, - "LLAMA_SPINQUANT": LLAMA_SPINQUANT, - "LLAMA_SPINQUANT_R1R2": LLAMA_SPINQUANT_R1R2, -} diff --git a/src/llmcompressor/modifiers/transform/quip/base.py b/src/llmcompressor/modifiers/transform/quip/base.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/llmcompressor/modifiers/transform/presets/quip.py b/src/llmcompressor/modifiers/transform/quip/template.py similarity index 100% rename from src/llmcompressor/modifiers/transform/presets/quip.py rename to src/llmcompressor/modifiers/transform/quip/template.py diff --git a/src/llmcompressor/modifiers/transform/spinquant/__init__.py b/src/llmcompressor/modifiers/transform/spinquant/__init__.py new file mode 100644 index 000000000..773cfc466 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/spinquant/__init__.py @@ -0,0 +1 @@ +from .base import * \ No newline at end of file diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py new file mode 100644 index 000000000..6d0c0cca3 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -0,0 +1,215 @@ +from typing import Optional, List, Literal + +from compressed_tensors.transform import TransformConfig, TransformScheme, TransformArgs, apply_transform_config +from pydantic import BaseModel, field_validator, Field + +from llmcompressor.core import Event, EventType, State +from llmcompressor.modeling import fuse_norm_linears +from llmcompressor.modifiers import Modifier +from enum import Enum + +from transformers import PreTrainedModel + + +class SpinQuantMappings(BaseModel): + embedding: str + + attn_q: str + attn_k: str + attn_v: str + attn_o: str + attn_head_dim: Optional[int] = Field(default=None) + + mlp_in: List[str] # up_proj, gate_proj + mlp_out: List[str] # down_proj + + lm_head: str + + @field_validator("mlp_in", "mlp_out", mode="before") + def cast_to_list(cls, value): + if isinstance(value, str): + return [value] + + return value + +class NormMapping(BaseModel): + norm: str + linears: List[str] + + @field_validator("linears", mode="before") + def cast_to_list(cls, value): + if isinstance(value, str): + return [value] + + return value + + + +llama_spinquant = SpinQuantMappings( + embedding="re:.*embed_tokens$", + + attn_q="re:.*q_proj$", + attn_k="re:.*k_proj$", + attn_v="re:.*v_proj$", + attn_o="re:.*o_proj$", + + mlp_in=["re:.*up_proj$", "re:.*gate_proj$"], + mlp_out="re:.*down_proj$", + + lm_head="lm_head", +) + +llama_norm_mappings = [ + NormMapping( + norm="re:.*input_layernorm$", + linears=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"], + ), + NormMapping( + norm="re:.*post_attention_layernorm$", + linears=["re:.*up_proj$", "re:.*gate_proj$"], + ) +] + +class SpinquantRotation(Enum): + R1 = "R1" + R2 = "R2" + R3 = "R3" + R4 = "R4" + +class SpinQuantModifier(Modifier): + rotations: List[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"]) + + transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(default="hadamard") + randomize: bool = Field(default=False) + learnable: bool = Field(default=False) + + mappings: Optional[SpinQuantMappings] = None + norm_mappings: Optional[List[NormMapping]] = None + + transform_config: Optional[TransformConfig] = None # optional override for more fine-grained control + + def on_initialize(self, state: State, **kwargs) -> bool: + # HARDCODE + self.mappings = llama_spinquant + self.norm_mappings = llama_norm_mappings + + if self.transform_config is not None: + if 
self.mappings is not None: + raise ValueError() + + return True + + config_groups = {} + for rotation in self.rotations: + if rotation == SpinquantRotation.R1: + config_groups["R1"] = self._create_r1_scheme() + + if rotation == SpinquantRotation.R2: + config_groups["R2"] = self._create_r2_scheme(state.model) + + if rotation == SpinquantRotation.R3: + config_groups["R3"] = self._create_r3_scheme() + + if rotation == SpinquantRotation.R4: + config_groups["R4"] = self._create_r4_scheme() + + self.transform_config = TransformConfig(config_groups=config_groups) + + return True + + def on_start(self, state: State, event: Event, **kwargs): + self.started_ = True + + for layer in state.model.model.layers: + fuse_norm_linears(layer.input_layernorm, (layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj)) + fuse_norm_linears(layer.post_attention_layernorm, (layer.mlp.gate_proj, layer.mlp.up_proj)) + + # needs to happen after the model has been hooked to execute on the GPU + # otherwise we're applying weight transforms on CPU + apply_transform_config(state.model, self.transform_config) + + + + + def on_event(self, state: State, event: Event, **kwargs): + if event.type_ == EventType.CALIBRATION_EPOCH_START: + if not self.started_: + self.on_start(state, None) + + elif event.type_ == EventType.SEQUENTIAL_EPOCH_END: + pass + + elif event.type_ == EventType.CALIBRATION_EPOCH_END: + if not self.ended_: + self.on_end(state, None) + + def on_end(self, state: State, event: Event, **kwargs): + self.ended_ = True + + def on_finalize(self, state: State, **kwargs) -> bool: + if not self.ended_: + self.on_end(state, None) + + return True + + + def _create_r1_scheme(self) -> TransformScheme: + return TransformScheme( + type=self.transform_type, + randomize=self.randomize, + requires_grad=self.learnable, + apply=[ + TransformArgs( + targets=[ + self.mappings.embedding, + self.mappings.attn_o, + *self.mappings.mlp_out, + ], + location="weight_output", + ), + TransformArgs( + targets=[ + self.mappings.attn_q, + self.mappings.attn_k, + self.mappings.attn_v, + *self.mappings.mlp_in, + self.mappings.lm_head + ], + location="weight_input", + inverse=True, + ), + ] + ) + + def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme: + config = model.config + + if hasattr(config, "head_dim"): + head_dim = config.head_dim + elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"): + head_dim = config.hidden_size // config.num_attention_heads + else: + raise NotImplementedError() + + return TransformScheme( + type=self.transform_type, + randomize=self.randomize, + requires_grad=self.learnable, + head_dim=head_dim, + apply=[ + TransformArgs(targets=[self.mappings.attn_v], location="weight_output"), + TransformArgs( + targets=[self.mappings.attn_o], + location="weight_input", + inverse=True, + ), + ], + ) + + + def _create_r3_scheme(self) -> TransformScheme: + raise NotImplementedError() + + + def _create_r4_scheme(self) -> TransformScheme: + raise NotImplementedError() \ No newline at end of file diff --git a/src/llmcompressor/modifiers/transform/presets/spinquant.py b/src/llmcompressor/modifiers/transform/spinquant/template.py similarity index 100% rename from src/llmcompressor/modifiers/transform/presets/spinquant.py rename to src/llmcompressor/modifiers/transform/spinquant/template.py diff --git a/src/llmcompressor/modifiers/transform/transform.py b/src/llmcompressor/modifiers/transform/transform.py deleted file mode 100644 index 3c59cde03..000000000 --- 
a/src/llmcompressor/modifiers/transform/transform.py +++ /dev/null @@ -1,65 +0,0 @@ -from typing import Optional - -from compressed_tensors.transform import TransformConfig, apply_transform_config -from pydantic import ValidationError, model_validator - -from llmcompressor.core import Event, EventType, State -from llmcompressor.modeling import fuse_norm_linears -from llmcompressor.modifiers import Modifier -from llmcompressor.modifiers.transform.presets import TRANSFORM_PRESETS - - -class TransformModifier(Modifier): - preset_config: Optional[str] = None - config: Optional[TransformConfig] = None - - # model validator to validate both preset and config groups are not provided - @model_validator(mode="after") - def validate_model_after(model: "TransformModifier") -> "TransformModifier": - if model.preset_config is None and model.config is None: - raise ValidationError("Either a config or a preset_config must be provided") - - if model.preset_config is not None: - if model.preset_config not in TRANSFORM_PRESETS: - raise ValidationError( - f"Invalid preset_config '{model.preset_config}' " - f"must be in {TRANSFORM_PRESETS.keys()}" - ) - model.config = TRANSFORM_PRESETS[model.preset_config] - - return model - - def on_initialize(self, state: State, **kwargs) -> bool: - return True - - def on_start(self, state: State, event: Event, **kwargs): - self.started_ = True - - for layer in state.model.model.layers: - fuse_norm_linears(layer.input_layernorm, (layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj)) - fuse_norm_linears(layer.post_attention_layernorm, (layer.mlp.gate_proj, layer.mlp.up_proj)) - - # needs to happen after the model has been hooked to execute on the GPU - # otherwise we're applying weight transforms on CPU - apply_transform_config(state.model, self.config) - - def on_event(self, state: State, event: Event, **kwargs): - if event.type_ == EventType.CALIBRATION_EPOCH_START: - if not self.started_: - self.on_start(state, None) - - elif event.type_ == EventType.SEQUENTIAL_EPOCH_END: - pass - - elif event.type_ == EventType.CALIBRATION_EPOCH_END: - if not self.ended_: - self.on_end(state, None) - - def on_end(self, state: State, event: Event, **kwargs): - self.ended_ = True - - def on_finalize(self, state: State, **kwargs) -> bool: - if not self.ended_: - self.on_end(state, None) - - return True From 9298e8268d8c6c11ebd7b6c4dd0a433c639f0971 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 11 Jul 2025 14:50:00 -0400 Subject: [PATCH 13/54] remove space Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/transform/spinquant/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 6d0c0cca3..8bf2e5cb1 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -78,7 +78,6 @@ class SpinquantRotation(Enum): class SpinQuantModifier(Modifier): rotations: List[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"]) - transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(default="hadamard") randomize: bool = Field(default=False) learnable: bool = Field(default=False) From f77226d12b8e6e7e5556f70b58b392d1b97d2025 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 11 Jul 2025 14:51:20 -0400 Subject: [PATCH 14/54] use iterable Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/transform/spinquant/base.py | 4 ++-- 1 file changed, 2 
insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 8bf2e5cb1..76f38361c 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Literal +from typing import Optional, List, Literal, Iterable from compressed_tensors.transform import TransformConfig, TransformScheme, TransformArgs, apply_transform_config from pydantic import BaseModel, field_validator, Field @@ -77,7 +77,7 @@ class SpinquantRotation(Enum): R4 = "R4" class SpinQuantModifier(Modifier): - rotations: List[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"]) + rotations: Iterable[SpinquantRotation] = ("R1", "R2") transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(default="hadamard") randomize: bool = Field(default=False) learnable: bool = Field(default=False) From fdb64b54876f81c5e34fe020840334bc616ba6d6 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 11 Jul 2025 14:58:24 -0400 Subject: [PATCH 15/54] add rotation validation Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/transform/spinquant/base.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 76f38361c..8e786fb7e 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -87,6 +87,12 @@ class SpinQuantModifier(Modifier): transform_config: Optional[TransformConfig] = None # optional override for more fine-grained control + @field_validator("rotations", mode="before") + def validate_rotations(cls, value): + if isinstance(value, Iterable): + return tuple(v.upper() for v in value) + return value + def on_initialize(self, state: State, **kwargs) -> bool: # HARDCODE self.mappings = llama_spinquant From 5daa2d5a0cb31f32911942762a51c4ea69822f48 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 11 Jul 2025 15:11:54 -0400 Subject: [PATCH 16/54] embedding fusion Signed-off-by: Kyle Sayers --- .../modifiers/transform/spinquant/base.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 8e786fb7e..813e1335a 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -125,6 +125,18 @@ def on_initialize(self, state: State, **kwargs) -> bool: def on_start(self, state: State, event: Event, **kwargs): self.started_ = True + # TODO: use norm mappings + # Embedding fusion + # theoretically, doesn't do anything. 
Doesn't seem to help model sanity either + from compressed_tensors import update_offload_parameter + for W in [state.model.model.embed_tokens]: + W_ = W.weight.data.double() + W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype) + + update_offload_parameter(state.model.model.embed_tokens, "weight", W.weight) + + # TODO: use norm mappings + # layer norm fusion for layer in state.model.model.layers: fuse_norm_linears(layer.input_layernorm, (layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj)) fuse_norm_linears(layer.post_attention_layernorm, (layer.mlp.gate_proj, layer.mlp.up_proj)) From 0e9af7b6d1ff8d574c373b741aeeaf3733b4ee47 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 12 Jul 2025 10:38:43 -0400 Subject: [PATCH 17/54] add missing norm fusion Signed-off-by: Kyle Sayers --- examples/transform/spinquant_dummy.py | 9 +-- src/llmcompressor/modeling/__init__.py | 2 +- src/llmcompressor/modeling/fuse.py | 18 +++-- .../modifiers/transform/spinquant/__init__.py | 2 +- .../modifiers/transform/spinquant/base.py | 75 +++++++++++-------- 5 files changed, 62 insertions(+), 44 deletions(-) diff --git a/examples/transform/spinquant_dummy.py b/examples/transform/spinquant_dummy.py index 3e8c9d483..71db967de 100644 --- a/examples/transform/spinquant_dummy.py +++ b/examples/transform/spinquant_dummy.py @@ -1,14 +1,13 @@ -from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer import torch from compressed_tensors.utils import update_parameter_data +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.models.llama.modeling_llama import LlamaRMSNorm + from llmcompressor import oneshot from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier from llmcompressor.modifiers.transform import SpinQuantModifier from llmcompressor.utils import dispatch_for_generation -from transformers.models.llama.modeling_llama import ( - LlamaRMSNorm, -) hidden_dim = intermediate_dim = 64 up_dim = 128 diff --git a/src/llmcompressor/modeling/__init__.py b/src/llmcompressor/modeling/__init__.py index 871955916..76b6b0391 100644 --- a/src/llmcompressor/modeling/__init__.py +++ b/src/llmcompressor/modeling/__init__.py @@ -1,4 +1,4 @@ # flake8: noqa +from .fuse import * from .prepare import * -from .fuse import * \ No newline at end of file diff --git a/src/llmcompressor/modeling/fuse.py b/src/llmcompressor/modeling/fuse.py index 3e059f7cb..cb88ecc22 100644 --- a/src/llmcompressor/modeling/fuse.py +++ b/src/llmcompressor/modeling/fuse.py @@ -1,8 +1,11 @@ from typing import Iterable import torch -from compressed_tensors import get_execution_device, align_module_device, update_offload_parameter - +from compressed_tensors import ( + align_module_device, + get_execution_device, + update_offload_parameter, +) from transformers.models.llama.modeling_llama import LlamaRMSNorm __all__ = ["fuse_norm_linears"] @@ -22,14 +25,17 @@ def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]) for linear in linears: # NOTE: spinquant does this op in float64 exec_device = get_execution_device(norm) - with align_module_device(norm, exec_device), align_module_device(linear, exec_device): - + with align_module_device(norm, exec_device), align_module_device( + linear, exec_device + ): weight_dtype = linear.weight.dtype - new_weight = linear.weight.to(torch.float64) * norm.weight.to(torch.float64) + new_weight = linear.weight.to(torch.float64) * 
norm.weight.to( + torch.float64 + ) new_weight = new_weight.to(weight_dtype) - + update_offload_parameter(linear, "weight", new_weight) update_offload_parameter(norm, "weight", torch.ones_like(norm.weight)) diff --git a/src/llmcompressor/modifiers/transform/spinquant/__init__.py b/src/llmcompressor/modifiers/transform/spinquant/__init__.py index 773cfc466..9b5ed21c9 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/__init__.py +++ b/src/llmcompressor/modifiers/transform/spinquant/__init__.py @@ -1 +1 @@ -from .base import * \ No newline at end of file +from .base import * diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 813e1335a..31b1bbdee 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -1,14 +1,18 @@ -from typing import Optional, List, Literal, Iterable +from enum import Enum +from typing import Iterable, List, Literal, Optional -from compressed_tensors.transform import TransformConfig, TransformScheme, TransformArgs, apply_transform_config -from pydantic import BaseModel, field_validator, Field +from compressed_tensors.transform import ( + TransformArgs, + TransformConfig, + TransformScheme, + apply_transform_config, +) +from pydantic import BaseModel, Field, field_validator +from transformers import PreTrainedModel from llmcompressor.core import Event, EventType, State from llmcompressor.modeling import fuse_norm_linears from llmcompressor.modifiers import Modifier -from enum import Enum - -from transformers import PreTrainedModel class SpinQuantMappings(BaseModel): @@ -29,9 +33,10 @@ class SpinQuantMappings(BaseModel): def cast_to_list(cls, value): if isinstance(value, str): return [value] - + return value - + + class NormMapping(BaseModel): norm: str linears: List[str] @@ -40,22 +45,18 @@ class NormMapping(BaseModel): def cast_to_list(cls, value): if isinstance(value, str): return [value] - - return value + return value llama_spinquant = SpinQuantMappings( embedding="re:.*embed_tokens$", - attn_q="re:.*q_proj$", attn_k="re:.*k_proj$", attn_v="re:.*v_proj$", attn_o="re:.*o_proj$", - mlp_in=["re:.*up_proj$", "re:.*gate_proj$"], mlp_out="re:.*down_proj$", - lm_head="lm_head", ) @@ -67,25 +68,31 @@ def cast_to_list(cls, value): NormMapping( norm="re:.*post_attention_layernorm$", linears=["re:.*up_proj$", "re:.*gate_proj$"], - ) + ), ] + class SpinquantRotation(Enum): R1 = "R1" R2 = "R2" R3 = "R3" R4 = "R4" + class SpinQuantModifier(Modifier): rotations: Iterable[SpinquantRotation] = ("R1", "R2") - transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(default="hadamard") + transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field( + default="hadamard" + ) randomize: bool = Field(default=False) learnable: bool = Field(default=False) mappings: Optional[SpinQuantMappings] = None norm_mappings: Optional[List[NormMapping]] = None - - transform_config: Optional[TransformConfig] = None # optional override for more fine-grained control + + transform_config: Optional[TransformConfig] = ( + None # optional override for more fine-grained control + ) @field_validator("rotations", mode="before") def validate_rotations(cls, value): @@ -101,7 +108,7 @@ def on_initialize(self, state: State, **kwargs) -> bool: if self.transform_config is not None: if self.mappings is not None: raise ValueError() - + return True config_groups = {} @@ -129,6 +136,7 @@ def on_start(self, state: 
State, event: Event, **kwargs): # Embedding fusion # theoretically, doesn't do anything. Doesn't seem to help model sanity either from compressed_tensors import update_offload_parameter + for W in [state.model.model.embed_tokens]: W_ = W.weight.data.double() W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype) @@ -138,16 +146,24 @@ def on_start(self, state: State, event: Event, **kwargs): # TODO: use norm mappings # layer norm fusion for layer in state.model.model.layers: - fuse_norm_linears(layer.input_layernorm, (layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj)) - fuse_norm_linears(layer.post_attention_layernorm, (layer.mlp.gate_proj, layer.mlp.up_proj)) + fuse_norm_linears( + layer.input_layernorm, + ( + layer.self_attn.q_proj, + layer.self_attn.k_proj, + layer.self_attn.v_proj, + ), + ) + fuse_norm_linears( + layer.post_attention_layernorm, (layer.mlp.gate_proj, layer.mlp.up_proj) + ) + + fuse_norm_linears(state.model.model.norm, (state.model.lm_head,)) # needs to happen after the model has been hooked to execute on the GPU # otherwise we're applying weight transforms on CPU apply_transform_config(state.model, self.transform_config) - - - def on_event(self, state: State, event: Event, **kwargs): if event.type_ == EventType.CALIBRATION_EPOCH_START: if not self.started_: @@ -169,7 +185,6 @@ def on_finalize(self, state: State, **kwargs) -> bool: return True - def _create_r1_scheme(self) -> TransformScheme: return TransformScheme( type=self.transform_type, @@ -190,14 +205,14 @@ def _create_r1_scheme(self) -> TransformScheme: self.mappings.attn_k, self.mappings.attn_v, *self.mappings.mlp_in, - self.mappings.lm_head + self.mappings.lm_head, ], location="weight_input", inverse=True, ), - ] + ], ) - + def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme: config = model.config @@ -207,7 +222,7 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme: head_dim = config.hidden_size // config.num_attention_heads else: raise NotImplementedError() - + return TransformScheme( type=self.transform_type, randomize=self.randomize, @@ -223,10 +238,8 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme: ], ) - def _create_r3_scheme(self) -> TransformScheme: raise NotImplementedError() - def _create_r4_scheme(self) -> TransformScheme: - raise NotImplementedError() \ No newline at end of file + raise NotImplementedError() From fce83be83f5c4ec01b1717263c1a6effcacf3e8d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 12 Jul 2025 12:40:27 -0400 Subject: [PATCH 18/54] use norm mappings Signed-off-by: Kyle Sayers --- src/llmcompressor/modeling/fuse.py | 25 ++++++-- .../modifiers/transform/spinquant/base.py | 64 ++++++++++--------- 2 files changed, 54 insertions(+), 35 deletions(-) diff --git a/src/llmcompressor/modeling/fuse.py b/src/llmcompressor/modeling/fuse.py index cb88ecc22..12e21f14b 100644 --- a/src/llmcompressor/modeling/fuse.py +++ b/src/llmcompressor/modeling/fuse.py @@ -8,7 +8,24 @@ ) from transformers.models.llama.modeling_llama import LlamaRMSNorm -__all__ = ["fuse_norm_linears"] +__all__ = ["normalize_embedding", "fuse_norm_linears"] + + +PRECISION = torch.float64 + + +def normalize_embedding(embedding: torch.nn.Module): + if isinstance(embedding, (torch.nn.Embedding)): + with align_module_device(embedding): + weight_dtype = embedding.weight.dtype + weight = embedding.weight.to(PRECISION) + new_weight = weight - weight.mean(dim=-1, keepdim=True) + new_weight = new_weight.to(weight_dtype) + + 
update_offload_parameter(embedding, "weight", new_weight) + + else: + raise ValueError(f"Cannot normalize embedding of type {type(embedding)}") def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]): @@ -29,11 +46,7 @@ def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]) linear, exec_device ): weight_dtype = linear.weight.dtype - - new_weight = linear.weight.to(torch.float64) * norm.weight.to( - torch.float64 - ) - + new_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION) new_weight = new_weight.to(weight_dtype) update_offload_parameter(linear, "weight", new_weight) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 31b1bbdee..c6b1c3087 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -1,6 +1,7 @@ from enum import Enum from typing import Iterable, List, Literal, Optional +from compressed_tensors import match_named_modules, is_match from compressed_tensors.transform import ( TransformArgs, TransformConfig, @@ -11,7 +12,7 @@ from transformers import PreTrainedModel from llmcompressor.core import Event, EventType, State -from llmcompressor.modeling import fuse_norm_linears +from llmcompressor.modeling import normalize_embedding, fuse_norm_linears from llmcompressor.modifiers import Modifier @@ -69,6 +70,10 @@ def cast_to_list(cls, value): norm="re:.*post_attention_layernorm$", linears=["re:.*up_proj$", "re:.*gate_proj$"], ), + NormMapping( + norm="model.norm", + linears=["lm_head"], + ), ] @@ -132,36 +137,10 @@ def on_initialize(self, state: State, **kwargs) -> bool: def on_start(self, state: State, event: Event, **kwargs): self.started_ = True - # TODO: use norm mappings - # Embedding fusion - # theoretically, doesn't do anything. 
Doesn't seem to help model sanity either - from compressed_tensors import update_offload_parameter - - for W in [state.model.model.embed_tokens]: - W_ = W.weight.data.double() - W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype) - - update_offload_parameter(state.model.model.embed_tokens, "weight", W.weight) - - # TODO: use norm mappings - # layer norm fusion - for layer in state.model.model.layers: - fuse_norm_linears( - layer.input_layernorm, - ( - layer.self_attn.q_proj, - layer.self_attn.k_proj, - layer.self_attn.v_proj, - ), - ) - fuse_norm_linears( - layer.post_attention_layernorm, (layer.mlp.gate_proj, layer.mlp.up_proj) - ) - - fuse_norm_linears(state.model.model.norm, (state.model.lm_head,)) - # needs to happen after the model has been hooked to execute on the GPU # otherwise we're applying weight transforms on CPU + self._prenormalize_embeddings(state.model) + self._fuse_norms(state.model) apply_transform_config(state.model, self.transform_config) def on_event(self, state: State, event: Event, **kwargs): @@ -185,6 +164,33 @@ def on_finalize(self, state: State, **kwargs) -> bool: return True + def _prenormalize_embeddings(self, model: PreTrainedModel): + for _, embedding in match_named_modules( + model, [self.mappings.embedding], warn_on_fail=True + ): + normalize_embedding(embedding) + + def _fuse_norms(self, model: PreTrainedModel): + for mapping in self.norm_mappings: + targets = (mapping.norm, *mapping.linears) + matches = dict() + + for name, module in model.named_modules(): + # match until we get a full set + for target in targets: + if is_match(name, module, target): + if target in matches: + raise ValueError("Cannot match twice") + matches[target] = module + + # once we have a full set, fuse and reset + if all(target in matches for target in targets): + fuse_norm_linears( + matches[mapping.norm], + (matches[target] for target in mapping.linears), + ) + matches = dict() + def _create_r1_scheme(self) -> TransformScheme: return TransformScheme( type=self.transform_type, From a979f8aff43ef81322d4b8934d03cb61fe65d360 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 12 Jul 2025 13:02:09 -0400 Subject: [PATCH 19/54] break into separate files Signed-off-by: Kyle Sayers --- .../modifiers/transform/spinquant/__init__.py | 2 + .../modifiers/transform/spinquant/base.py | 83 ++++--------------- .../modifiers/transform/spinquant/mappings.py | 42 ++++++++++ .../transform/spinquant/norm_mappings.py | 35 ++++++++ 4 files changed, 93 insertions(+), 69 deletions(-) create mode 100644 src/llmcompressor/modifiers/transform/spinquant/mappings.py create mode 100644 src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py diff --git a/src/llmcompressor/modifiers/transform/spinquant/__init__.py b/src/llmcompressor/modifiers/transform/spinquant/__init__.py index 9b5ed21c9..8bdc93d14 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/__init__.py +++ b/src/llmcompressor/modifiers/transform/spinquant/__init__.py @@ -1 +1,3 @@ +# flake8: noqa + from .base import * diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index c6b1c3087..7c76aeca5 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -1,80 +1,22 @@ from enum import Enum from typing import Iterable, List, Literal, Optional -from compressed_tensors import match_named_modules, is_match +from compressed_tensors import is_match, 
match_named_modules from compressed_tensors.transform import ( TransformArgs, TransformConfig, TransformScheme, apply_transform_config, ) -from pydantic import BaseModel, Field, field_validator +from pydantic import Field, field_validator from transformers import PreTrainedModel from llmcompressor.core import Event, EventType, State -from llmcompressor.modeling import normalize_embedding, fuse_norm_linears +from llmcompressor.modeling import fuse_norm_linears, normalize_embedding from llmcompressor.modifiers import Modifier - -class SpinQuantMappings(BaseModel): - embedding: str - - attn_q: str - attn_k: str - attn_v: str - attn_o: str - attn_head_dim: Optional[int] = Field(default=None) - - mlp_in: List[str] # up_proj, gate_proj - mlp_out: List[str] # down_proj - - lm_head: str - - @field_validator("mlp_in", "mlp_out", mode="before") - def cast_to_list(cls, value): - if isinstance(value, str): - return [value] - - return value - - -class NormMapping(BaseModel): - norm: str - linears: List[str] - - @field_validator("linears", mode="before") - def cast_to_list(cls, value): - if isinstance(value, str): - return [value] - - return value - - -llama_spinquant = SpinQuantMappings( - embedding="re:.*embed_tokens$", - attn_q="re:.*q_proj$", - attn_k="re:.*k_proj$", - attn_v="re:.*v_proj$", - attn_o="re:.*o_proj$", - mlp_in=["re:.*up_proj$", "re:.*gate_proj$"], - mlp_out="re:.*down_proj$", - lm_head="lm_head", -) - -llama_norm_mappings = [ - NormMapping( - norm="re:.*input_layernorm$", - linears=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"], - ), - NormMapping( - norm="re:.*post_attention_layernorm$", - linears=["re:.*up_proj$", "re:.*gate_proj$"], - ), - NormMapping( - norm="model.norm", - linears=["lm_head"], - ), -] +from .mappings import SPINQUANT_MAPPING_REGISTRY, SpinQuantMappings +from .norm_mappings import NORM_MAPPING_REGISTRY, NormMapping class SpinquantRotation(Enum): @@ -92,12 +34,15 @@ class SpinQuantModifier(Modifier): randomize: bool = Field(default=False) learnable: bool = Field(default=False) + # norm mappings separate from spinquant mappings to allow users to + # override spinquant mappings with transform_config without overriding norms + # we can combine these mappings, but it requires some more validation logic + # maybe there's a reason to keep if other modifiers want norm fusing, idk mappings: Optional[SpinQuantMappings] = None norm_mappings: Optional[List[NormMapping]] = None - transform_config: Optional[TransformConfig] = ( - None # optional override for more fine-grained control - ) + # optional override for more fine-grained control + transform_config: Optional[TransformConfig] = None @field_validator("rotations", mode="before") def validate_rotations(cls, value): @@ -106,9 +51,9 @@ def validate_rotations(cls, value): return value def on_initialize(self, state: State, **kwargs) -> bool: - # HARDCODE - self.mappings = llama_spinquant - self.norm_mappings = llama_norm_mappings + # TODO: more validation + self.mappings = SPINQUANT_MAPPING_REGISTRY[state.model.__class__.__name__] + self.norm_mappings = NORM_MAPPING_REGISTRY[state.model.__class__.__name__] if self.transform_config is not None: if self.mappings is not None: diff --git a/src/llmcompressor/modifiers/transform/spinquant/mappings.py b/src/llmcompressor/modifiers/transform/spinquant/mappings.py new file mode 100644 index 000000000..acf692d22 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/spinquant/mappings.py @@ -0,0 +1,42 @@ +from typing import Dict, List, Optional + +from pydantic import BaseModel, 
Field, field_validator + + +class SpinQuantMappings(BaseModel): + embedding: str + + attn_q: str + attn_k: str + attn_v: str + attn_o: str + attn_head_dim: Optional[int] = Field(default=None) + + mlp_in: List[str] # up_proj, gate_proj + mlp_out: List[str] # down_proj + + lm_head: str + + @field_validator("mlp_in", "mlp_out", mode="before") + def cast_to_list(cls, value): + if isinstance(value, str): + return [value] + + return value + + +_default_mappings = SpinQuantMappings( + embedding="re:.*embed_tokens$", + attn_q="re:.*q_proj$", + attn_k="re:.*k_proj$", + attn_v="re:.*v_proj$", + attn_o="re:.*o_proj$", + mlp_in=["re:.*up_proj$", "re:.*gate_proj$"], + mlp_out="re:.*down_proj$", + lm_head="lm_head", +) + + +SPINQUANT_MAPPING_REGISTRY: Dict[str, SpinQuantMappings] = { + "LlamaForCausalLM": _default_mappings, +} diff --git a/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py b/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py new file mode 100644 index 000000000..cefb987ca --- /dev/null +++ b/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py @@ -0,0 +1,35 @@ +from typing import Dict, List + +from pydantic import BaseModel, field_validator + + +class NormMapping(BaseModel): + norm: str + linears: List[str] + + @field_validator("linears", mode="before") + def cast_to_list(cls, value): + if isinstance(value, str): + return [value] + + return value + + +_default_norm_mappings = [ + NormMapping( + norm="re:.*input_layernorm$", + linears=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"], + ), + NormMapping( + norm="re:.*post_attention_layernorm$", + linears=["re:.*up_proj$", "re:.*gate_proj$"], + ), + NormMapping( + norm="model.norm", + linears=["lm_head"], + ), +] + +NORM_MAPPING_REGISTRY: Dict[str, NormMapping] = { + "LlamaForCausalLM": _default_norm_mappings, +} From 4cab29ef7060e6f67c43881fa44adeae2a0c4258 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 12 Jul 2025 14:37:34 -0400 Subject: [PATCH 20/54] small cleanup Signed-off-by: Kyle Sayers --- .../modifiers/transform/spinquant/base.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 7c76aeca5..e448bd372 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -62,18 +62,17 @@ def on_initialize(self, state: State, **kwargs) -> bool: return True config_groups = {} - for rotation in self.rotations: - if rotation == SpinquantRotation.R1: - config_groups["R1"] = self._create_r1_scheme() + if SpinquantRotation.R1 in self.rotations: + config_groups["R1"] = self._create_r1_scheme() - if rotation == SpinquantRotation.R2: - config_groups["R2"] = self._create_r2_scheme(state.model) + if SpinquantRotation.R2 in self.rotations: + config_groups["R2"] = self._create_r2_scheme(state.model) - if rotation == SpinquantRotation.R3: - config_groups["R3"] = self._create_r3_scheme() + if SpinquantRotation.R3 in self.rotations: + config_groups["R3"] = self._create_r3_scheme() - if rotation == SpinquantRotation.R4: - config_groups["R4"] = self._create_r4_scheme() + if SpinquantRotation.R4 in self.rotations: + config_groups["R4"] = self._create_r4_scheme() self.transform_config = TransformConfig(config_groups=config_groups) From f1cc987c00163705b46e5ad286a0e87732196323 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 14 Jul 2025 15:40:18 -0400 Subject: [PATCH 21/54] cleanup 
Signed-off-by: Kyle Sayers --- examples/transform/llama3_example.py | 32 ++++----- src/llmcompressor/entrypoints/oneshot.py | 3 +- .../modifiers/transform/test_dummy_model.py | 70 +++++++++---------- 3 files changed, 47 insertions(+), 58 deletions(-) rename examples/transform/spinquant_dummy.py => tests/llmcompressor/modifiers/transform/test_dummy_model.py (68%) diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py index 8c87cb6a6..790619b08 100644 --- a/examples/transform/llama3_example.py +++ b/examples/transform/llama3_example.py @@ -2,13 +2,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from llmcompressor import oneshot -from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier +from llmcompressor.modifiers.quantization import QuantizationModifier from llmcompressor.modifiers.transform import SpinQuantModifier from llmcompressor.utils import dispatch_for_generation # Select model and load it. -# MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct" -# MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct" # TODO hidden size 3072 causes failure when creating hadamard MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" model = AutoModelForCausalLM.from_pretrained( @@ -57,36 +55,32 @@ def tokenize(sample): ds = ds.map(tokenize, remove_columns=ds.column_names) # Configure the quantization algorithm to run. +# * apply spinquant transforms to model in order to make quantization easier # * quantize the weights to 4 bit with GPTQ with a group size 128 recipe = [ - # TODO preset_config="QUIP_ONLINE" outputs gibberish - # preset_config="QUIP" output sensible, but cannot load saved - # checkpoint or run evals (~4hrs to run) - SpinQuantModifier(rotations=["R1", "R2"]), - # QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), + SpinQuantModifier(rotations=["R1", "R2"], transform_type="random-hadamard"), + QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), ] # Apply algorithms. oneshot( model=model, recipe=recipe, - # dataset=ds, - pipeline="datafree", - # max_seq_length=MAX_SEQUENCE_LENGTH, - # num_calibration_samples=NUM_CALIBRATION_SAMPLES, - log_dir=None, + dataset=ds, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, ) -# # Confirm generations of the quantized model look sane. +# Confirm generations of the quantized model look sane. print("\n\n") print("========== SAMPLE GENERATION ==============") dispatch_for_generation(model) input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") output = model.generate(input_ids, max_new_tokens=100) print(tokenizer.decode(output[0])) -# print("==========================================\n\n") +print("==========================================\n\n") -# # Save to disk compressed. -# SAVE_DIR = MODEL_ID.split("/")[1] + "-transform-quant-w4a16" -# model.save_pretrained(SAVE_DIR, save_compressed=True) -# tokenizer.save_pretrained(SAVE_DIR) +# Save to disk compressed. 
+SAVE_DIR = MODEL_ID.split("/")[1] + "-transformed-w4a16" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index cfd3b551f..9219f21fb 100644 --- a/src/llmcompressor/entrypoints/oneshot.py +++ b/src/llmcompressor/entrypoints/oneshot.py @@ -125,8 +125,7 @@ def __init__( self.output_dir = output_dir # initialize the model and processor - # TODO Remove Comment before merge, this is just needed for DummyModel - # pre_process(model_args) + pre_process(model_args) # Set instance attributes self.model = self.model_args.model diff --git a/examples/transform/spinquant_dummy.py b/tests/llmcompressor/modifiers/transform/test_dummy_model.py similarity index 68% rename from examples/transform/spinquant_dummy.py rename to tests/llmcompressor/modifiers/transform/test_dummy_model.py index 71db967de..020a61e99 100644 --- a/examples/transform/spinquant_dummy.py +++ b/tests/llmcompressor/modifiers/transform/test_dummy_model.py @@ -14,9 +14,6 @@ num_embeddings = 12 -# TODO remove file before merging - - class DummySelfAttn(torch.nn.Module): def __init__(self, hidden_dim, intermediate_dim): super().__init__() @@ -75,37 +72,36 @@ def forward(self, input_ids): return self.lm_head(x) -model = DummyModel(num_embeddings, hidden_dim, intermediate_dim, up_dim) - -# TODO Uncomment this to see norm diff > 1e-6 -# This is due to issue Kyle spotted in https://arxiv.org/pdf/2405.16406 Page 5 Footnote 2 -# Will have to fuse layernorms with subsequent layers so that input_layernorm.weight is equal to torch.ones() (this apparently makes it rotation invariant) -# https://github.com/facebookresearch/SpinQuant/blob/8f47aa3f00e8662caf1a484153920a07e5281c3a/utils/fuse_norm_utils.py#L39 -# update_parameter_data( -# model.input_layernorm, -# torch.rand(model.input_layernorm.weight.shape), -# "weight", -# ) - -input_ids = torch.IntTensor([1, 2, 3, 4, 5]) -orig_output = model(input_ids) - -recipe = [ - # NOTE: preset_config="QUIP" output sensible, but cannot load saved - # checkpoint or run evals (~4hrs to run) - SpinQuantModifier(rotations=["R1", "R2"]), - # QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), -] - -oneshot( - model=model, - recipe=recipe, - pipeline="datafree", - log_dir=None, -) - -# # Confirm generations of the quantized model look the same -transformed_output = model(input_ids) - -print(f"Norm Diff {(orig_output-transformed_output).norm()}") -print(f"Norm {orig_output.norm()}, {transformed_output.norm()}") +def test_dummy_model(): + model = DummyModel(num_embeddings, hidden_dim, intermediate_dim, up_dim) + + # TODO Uncomment this to see norm diff > 1e-6 + # This is due to issue Kyle spotted in https://arxiv.org/pdf/2405.16406 Page 5 Footnote 2 + # Will have to fuse layernorms with subsequent layers so that input_layernorm.weight is equal to torch.ones() (this apparently makes it rotation invariant) + # https://github.com/facebookresearch/SpinQuant/blob/8f47aa3f00e8662caf1a484153920a07e5281c3a/utils/fuse_norm_utils.py#L39 + # update_parameter_data( + # model.input_layernorm, + # torch.rand(model.input_layernorm.weight.shape), + # "weight", + # ) + + input_ids = torch.IntTensor([1, 2, 3, 4, 5]) + orig_output = model(input_ids) + + recipe = [ + SpinQuantModifier(rotations=["R1", "R2"]), + ] + + # TODO: work around preprocessing? 
+ oneshot( + model=model, + recipe=recipe, + pipeline="datafree", + log_dir=None, + ) + + # # Confirm generations of the quantized model look the same + transformed_output = model(input_ids) + + print(f"Norm Diff {(orig_output-transformed_output).norm()}") + print(f"Norm {orig_output.norm()}, {transformed_output.norm()}") From a7bb2e2872cca3421e877de62bcb8a195f63a223 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 14 Jul 2025 15:40:55 -0400 Subject: [PATCH 22/54] more cleanup Signed-off-by: Kyle Sayers --- .../modifiers/transform/spinquant/template.py | 130 ------------------ 1 file changed, 130 deletions(-) delete mode 100644 src/llmcompressor/modifiers/transform/spinquant/template.py diff --git a/src/llmcompressor/modifiers/transform/spinquant/template.py b/src/llmcompressor/modifiers/transform/spinquant/template.py deleted file mode 100644 index 62dfb2477..000000000 --- a/src/llmcompressor/modifiers/transform/spinquant/template.py +++ /dev/null @@ -1,130 +0,0 @@ -from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme - -# Ref: https://arxiv.org/pdf/2405.16406 Fig 1 - -# Mergeable rotations R1 and R2 only -LLAMA_SPINQUANT_R1R2 = TransformConfig( - config_groups={ - "R1": TransformScheme( - type="hadamard", - apply=[ - TransformArgs( - targets=[ - # outermost rotation - "re:.*embed_tokens$", - # attention rotations - "re:.*o_proj$", - # mlp rotations - "re:.*down_proj$", - ], - location="weight_output", - ), - TransformArgs( - targets=[ - # outermost rotation - "lm_head", - # attention rotations - "re:.*q_proj$", - "re:.*k_proj$", - "re:.*v_proj$", - # mlp rotations - "re:.*up_proj$", - "re:.*gate_proj$", - ], - location="weight_input", - inverse=True, - ), - ], - ), - "R2": TransformScheme( - type="hadamard", - # TODO infer head_dim from config.json in SpinQuantModifier - head_dim=128, - apply=[ - TransformArgs(targets=["re:.*v_proj$"], location="weight_output"), - TransformArgs( - targets=["re:.*o_proj$"], - location="weight_input", - inverse=True, - ), - ], - ), - } -) - -# All rotations -LLAMA_SPINQUANT = TransformConfig( - config_groups={ - "R1": TransformScheme( - type="hadamard", - apply=[ - TransformArgs( - targets=[ - # outermost rotation - "re:.*embed_tokens$", - # attention rotations - "re:.*o_proj$", - # mlp rotations - "re:.*down_proj$", - ], - location="weight_output", - ), - TransformArgs( - targets=[ - # outermost rotation - "lm_head", - # attention rotations - "re:.*q_proj$", - "re:.*k_proj$", - "re:.*v_proj$", - # mlp rotations - "re:.*up_proj$", - "re:.*gate_proj$", - ], - location="weight_input", - inverse=True, - ), - ], - ), - "R2": TransformScheme( - type="hadamard", - # TODO infer head_dim from config.json in SpinQuantModifier - head_dim=128, - apply=[ - TransformArgs(targets=["re:.*v_proj$"], location="weight_output"), - TransformArgs( - targets=["re:.*o_proj$"], - location="weight_input", - inverse=True, - ), - ], - ), - # "R1": LLAMA_SPINQUANT_R1R2.config_groups["R1"], - # "R2": LLAMA_SPINQUANT_R1R2.config_groups["R2"], - "R3": TransformScheme( - type="hadamard", - apply=[ - TransformArgs( - targets=["re:.*self_attn$"], - location="k_cache", - ), - TransformArgs( - targets=["re:.*self_attn$"], - location="q_attn", - ), - ], - ), - "R4": TransformScheme( - type="hadamard", - apply=[ - TransformArgs( - targets=["re:.*down_proj$"], - location="input", - ), - TransformArgs( - targets=["re:.*down_proj$"], location="weight_input", inverse=True - ), - ], - ), - } -) From 0cf0188987898587c6f5d96a53b97264c8ee0435 Mon Sep 17 
00:00:00 2001 From: Kyle Sayers Date: Mon, 14 Jul 2025 15:48:49 -0400 Subject: [PATCH 23/54] make new weight on cpu Signed-off-by: Kyle Sayers --- src/llmcompressor/modeling/fuse.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/modeling/fuse.py b/src/llmcompressor/modeling/fuse.py index 12e21f14b..33e91601c 100644 --- a/src/llmcompressor/modeling/fuse.py +++ b/src/llmcompressor/modeling/fuse.py @@ -51,7 +51,8 @@ def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]) update_offload_parameter(linear, "weight", new_weight) - update_offload_parameter(norm, "weight", torch.ones_like(norm.weight)) + new_norm_weight = torch.ones_like(norm.weight, device="cpu") + update_offload_parameter(norm, "weight", new_norm_weight) else: raise ValueError(f"Cannot fuse norm of type {type(norm)}") From 53ea3076161f8562fe7653f9f6cb57c48da75ae4 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 14 Jul 2025 16:35:55 -0400 Subject: [PATCH 24/54] standardize, make modifier serializable Signed-off-by: Kyle Sayers --- examples/transform/llama3_example.py | 2 +- .../modifiers/transform/quip/base.py | 0 .../modifiers/transform/quip/template.py | 98 ------------------- .../modifiers/transform/spinquant/base.py | 13 +-- src/llmcompressor/pipelines/registry.py | 5 + 5 files changed, 13 insertions(+), 105 deletions(-) delete mode 100644 src/llmcompressor/modifiers/transform/quip/base.py delete mode 100644 src/llmcompressor/modifiers/transform/quip/template.py diff --git a/examples/transform/llama3_example.py b/examples/transform/llama3_example.py index 790619b08..876db7138 100644 --- a/examples/transform/llama3_example.py +++ b/examples/transform/llama3_example.py @@ -58,7 +58,7 @@ def tokenize(sample): # * apply spinquant transforms to model in order to make quantization easier # * quantize the weights to 4 bit with GPTQ with a group size 128 recipe = [ - SpinQuantModifier(rotations=["R1", "R2"], transform_type="random-hadamard"), + SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard"), QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), ] diff --git a/src/llmcompressor/modifiers/transform/quip/base.py b/src/llmcompressor/modifiers/transform/quip/base.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/llmcompressor/modifiers/transform/quip/template.py b/src/llmcompressor/modifiers/transform/quip/template.py deleted file mode 100644 index 4ce5e47ae..000000000 --- a/src/llmcompressor/modifiers/transform/quip/template.py +++ /dev/null @@ -1,98 +0,0 @@ -from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme - -QUIP = TransformConfig( - config_groups={ - "v": TransformScheme( - type="random-hadamard", - apply=[ - TransformArgs( - targets=["Linear"], - location="input", # non-mergable - ignore="lm_head", - ), - TransformArgs( - targets=["Linear"], - location="weight_input", - inverse=True, - ignore="lm_head", - ), - ], - randomize=True, - ), - "u": TransformScheme( - type="random-hadamard", - apply=[ - TransformArgs( - targets=["Linear"], - location="weight_output", - ignore="lm_head", - ), - TransformArgs( - targets=["Linear"], - location="output", # non-mergable - inverse=True, - ignore="lm_head", - ), - ], - randomize=True, - ), - } -) - -# https://github.com/vllm-project/llm-compressor/blob/b43b27a2f277a5e62be4f8c713b84fd1c7aa116b/weight_transform.py#L24-L105 -QUIP_ONLINE = TransformConfig( - config_groups={ - "u_transform_q_o_down_proj": TransformScheme( - 
type="hadamard", - apply=[ - TransformArgs( - targets=[ - "re:.*.attn.q_proj$", - "re:.*.attn.o_proj$", - "re:.*.mlp.down_proj$", - ], - location="weight_input", - ) - ], - ), - "u_transform_k_v_proj": TransformScheme( - type="hadamard", - apply=[ - TransformArgs( - targets=["re:.*.attn.k_proj$", "re:.*.attn.v_proj$"], - location="weight_input", - ) - ], - ), - "u_transform_gate_up_proj": TransformScheme( - type="hadamard", - apply=[ - TransformArgs( - targets=["re:.*.mlp.gate_proj$", "re:.*.mlp.up_proj$"], - location="weight_input", - ) - ], - ), - "v_transform_linear": TransformScheme( - type="hadamard", - apply=[ - TransformArgs( - targets=["Linear"], - location="weight_output", - ignore=["re:.*.mlp.down_proj$", "lm_head"], - inverse=True, - ) - ], - ), - "v_transform_down_proj": TransformScheme( - type="hadamard", - apply=[ - TransformArgs( - targets=["re:.*.mlp.down_proj$"], - location="weight_output", - inverse=True, - ) - ], - ), - } -) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index e448bd372..5997fac19 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -19,15 +19,15 @@ from .norm_mappings import NORM_MAPPING_REGISTRY, NormMapping -class SpinquantRotation(Enum): +class SpinquantRotation(str, Enum): R1 = "R1" R2 = "R2" R3 = "R3" R4 = "R4" -class SpinQuantModifier(Modifier): - rotations: Iterable[SpinquantRotation] = ("R1", "R2") +class SpinQuantModifier(Modifier, use_enum_values=True): + rotations: List[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"]) transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field( default="hadamard" ) @@ -38,11 +38,12 @@ class SpinQuantModifier(Modifier): # override spinquant mappings with transform_config without overriding norms # we can combine these mappings, but it requires some more validation logic # maybe there's a reason to keep if other modifiers want norm fusing, idk - mappings: Optional[SpinQuantMappings] = None - norm_mappings: Optional[List[NormMapping]] = None + mappings: Optional[SpinQuantMappings] = Field(default=None, exclude=True) + norm_mappings: Optional[List[NormMapping]] = Field(default=None, exclude=True) # optional override for more fine-grained control - transform_config: Optional[TransformConfig] = None + # also included in recipe serialization + transform_config: Optional[TransformConfig] = Field(default=None) @field_validator("rotations", mode="before") def validate_rotations(cls, value): diff --git a/src/llmcompressor/pipelines/registry.py b/src/llmcompressor/pipelines/registry.py index 2c1a54cf5..98fb836b0 100644 --- a/src/llmcompressor/pipelines/registry.py +++ b/src/llmcompressor/pipelines/registry.py @@ -8,6 +8,7 @@ from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.quantization import QuantizationModifier +from llmcompressor.modifiers.transform import SpinQuantModifier if TYPE_CHECKING: from llmcompressor.args.dataset_arguments import DatasetArguments @@ -60,5 +61,9 @@ def _infer_pipeline(modifiers: List[Modifier]) -> str: config = modifiers[0].resolve_quantization_config() if not config.requires_calibration_data(): return "datafree" + + # TODO: Remove hardcode + if len(modifiers) == 1 and isinstance(modifiers[0], SpinQuantModifier): + return "datafree" return "sequential" From 4b4257fe871df0f10b13e8ab9ee16f058a8456ed Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 14 Jul 2025 16:50:10 
-0400 Subject: [PATCH 25/54] add compress model script Signed-off-by: Kyle Sayers --- compress_model.py | 60 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 compress_model.py diff --git a/compress_model.py b/compress_model.py new file mode 100644 index 000000000..fa67bead0 --- /dev/null +++ b/compress_model.py @@ -0,0 +1,60 @@ +# python3 compress_model.py --model_id meta-llama/Llama-3.2-1B-Instruct --transform_type random-hadamard +import argparse +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llmcompressor import oneshot +from llmcompressor.modifiers.quantization import QuantizationModifier +from llmcompressor.modifiers.transform import SpinQuantModifier +from llmcompressor.utils import dispatch_for_generation + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_id", type=str, help="Model stub to compress") + parser.add_argument("--transform_type", type=str, default=None, help="Type of transform used in SpinQuantModifier") + parser.add_argument("--scheme", type=str, default=None, help="Quantization scheme (e.g. W4A16)") + return parser.parse_args() + +if __name__ == "__main__": + args = parse_args() + + # Select model and load it. + MODEL_ID = args.model_id + model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + + # Select number of samples. 512 samples is a good place to start. + # Increasing the number of samples can improve accuracy. + NUM_CALIBRATION_SAMPLES = 512 + MAX_SEQUENCE_LENGTH = 2048 + + # Configure the quantization algorithm to run. + recipe = [] + if args.transform_type: + recipe.append(SpinQuantModifier(rotations=["R1", "R2"], transform_type=args.transform_type)) + + if args.scheme: + recipe.append(QuantizationModifier(targets="Linear", scheme=args.scheme, ignore=["lm_head"])) + + # Apply algorithms. + oneshot( + model=model, + recipe=recipe, + dataset="ultrachat_200k", + splits={"calibration": f"train_sft[:{NUM_CALIBRATION_SAMPLES}]"}, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + ) + + # Confirm generations of the quantized model look sane. + print("\n\n") + print("========== SAMPLE GENERATION ==============") + dispatch_for_generation(model) + input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") + output = model.generate(input_ids, max_new_tokens=100) + print(tokenizer.decode(output[0])) + print("==========================================\n\n") + + # Save to disk compressed. 
+ SAVE_DIR = MODEL_ID.split("/")[1] + f"-{args.transform_type}-{args.scheme}" + model.save_pretrained(SAVE_DIR, save_compressed=True) + tokenizer.save_pretrained(SAVE_DIR) From dc7ac1a1e4a94c8402f003d90eaa5a75dccabb21 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 15 Jul 2025 11:08:09 -0400 Subject: [PATCH 26/54] use untie_word_embeddings Signed-off-by: Kyle Sayers --- src/llmcompressor/entrypoints/utils.py | 7 ++- .../compressed_tensors_utils.py | 50 +++++++++---------- .../test_compress_tensor_utils.py | 42 ++++++---------- 3 files changed, 43 insertions(+), 56 deletions(-) diff --git a/src/llmcompressor/entrypoints/utils.py b/src/llmcompressor/entrypoints/utils.py index 5647e4d06..95ec832fb 100644 --- a/src/llmcompressor/entrypoints/utils.py +++ b/src/llmcompressor/entrypoints/utils.py @@ -20,7 +20,7 @@ from llmcompressor.pytorch.model_load.helpers import parse_dtype from llmcompressor.transformers.sparsification.compressed_tensors_utils import ( modify_save_pretrained, - patch_tied_tensors_bug, + untie_word_embeddings, ) from llmcompressor.transformers.utils.helpers import ( detect_last_checkpoint, @@ -61,7 +61,8 @@ def pre_process(model_args: "ModelArguments"): ) # untie tie_word_embeddings weights - patch_tied_tensors_bug(model_args.model) + if not model_args.tie_word_embeddings: + untie_word_embeddings(model_args.model) # wrap model.save_pretrained modify_save_pretrained(model_args.model) @@ -143,7 +144,6 @@ def initialize_model_from_path( cache_dir=model_args.cache_dir, revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, - tie_word_embeddings=model_args.tie_word_embeddings, trust_remote_code=model_args.trust_remote_code_model, ) @@ -156,7 +156,6 @@ def initialize_model_from_path( AutoConfig.from_pretrained( model_args.distill_teacher, use_auth_token=True if model_args.use_auth_token else None, - tie_word_embeddings=model_args.tie_word_embeddings, trust_remote_code=model_args.trust_remote_code_model, ) if model_args.distill_teacher diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py index 69b0e3f28..0fdaa9dc6 100644 --- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py @@ -9,8 +9,9 @@ CompressionFormat, ModelCompressor, SparsityCompressionConfig, + delete_offload_parameter, is_module_offloaded, - update_offload_parameter, + register_offload_parameter, ) from loguru import logger from safetensors.torch import storage_ptr @@ -27,7 +28,7 @@ from llmcompressor.transformers.utils import RECIPE_FILE_NAME from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path -__all__ = ["modify_save_pretrained"] +__all__ = ["modify_save_pretrained", "untie_word_embeddings"] def modify_save_pretrained(model: PreTrainedModel): @@ -120,7 +121,7 @@ def save_pretrained_wrapper( model.save_pretrained = save_pretrained_compressed(model.save_pretrained) -def patch_tied_tensors_bug(model: torch.nn.Module): +def untie_word_embeddings(model: PreTrainedModel): """ Patches bug where HF transformers will fail to untie weights under specific circumstances (https://github.com/huggingface/transformers/issues/33689). 
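Because R1 folds a rotation into the `embed_tokens` weight (weight_output) and its inverse into `lm_head` (weight_input), those two matrices can no longer share storage, so tied word embeddings have to be split into independent parameters before the transform config is applied. Below is a minimal, hedged sketch of that untying step for a standard `transformers` causal LM; it is simplified from the offload-aware version introduced in this patch (no accelerate hooks, hypothetical helper name).

import torch
from transformers import AutoModelForCausalLM

def untie_word_embeddings_sketch(model):
    # Give input and output embeddings distinct storage by cloning the shared tensor.
    input_embed = model.get_input_embeddings()
    output_embed = model.get_output_embeddings()
    for module in (input_embed, output_embed):
        if module is None or not hasattr(module, "weight"):
            continue
        untied = module.weight.data.clone()
        module.weight = torch.nn.Parameter(
            untied, requires_grad=module.weight.requires_grad
        )
    # Record that the weights are no longer tied so save/reload stays consistent.
    if hasattr(model.config, "tie_word_embeddings"):
        model.config.tie_word_embeddings = False

# Usage (model id is illustrative):
# model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
# untie_word_embeddings_sketch(model)
# assert (model.get_input_embeddings().weight.data_ptr()
#         != model.get_output_embeddings().weight.data_ptr())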
@@ -129,28 +130,27 @@ def patch_tied_tensors_bug(model: torch.nn.Module): :param model: model to fix """ - if ( - hasattr(model.config, "tie_word_embeddings") - and not model.config.tie_word_embeddings - ): - input_embed = model.get_input_embeddings() - output_embed = model.get_output_embeddings() - - if input_embed is None or output_embed is None: - # some models fail to properly override the abstract methods - return - - if storage_ptr(input_embed.weight) == storage_ptr(output_embed.weight): - for module in (input_embed, output_embed): - if not is_module_offloaded(module): - # create new storage ptr for onloaded weight - untied_data = module.weight.data.clone() - module.weight.data = untied_data - else: - # create new storage ptr for offloaded weight - # note `update_offload_parameter` does not create a new storage ptr - untied_data = module._hf_hook.weights_map["weight"].clone() - update_offload_parameter(module, "weight", untied_data) + input_embed = model.get_input_embeddings() + output_embed = model.get_output_embeddings() + + for module in (input_embed, output_embed): + if module is None or not hasattr(module, "weight"): + logger.warning(f"Cannot untie {module} which does not have weight param") + continue + + # this could be replaced by a `get_offloaded_parameter` util + if not is_module_offloaded(module): + untied_data = module.weight.data.clone() + else: + untied_data = module._hf_hook.weights_map["weight"].clone() + + requires_grad = module.weight.requires_grad + new_parameter = torch.nn.Parameter(untied_data, requires_grad=requires_grad) + delete_offload_parameter(module, "weight") + register_offload_parameter(module, "weight", new_parameter) + + if hasattr(model.config, "tie_word_embeddings"): + model.config.tie_word_embeddings = False def get_model_compressor( diff --git a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py index 140e706d1..aad551ff8 100644 --- a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py +++ b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py @@ -28,7 +28,7 @@ from llmcompressor.transformers.sparsification.compressed_tensors_utils import ( get_model_compressor, modify_save_pretrained, - patch_tied_tensors_bug, + untie_word_embeddings, ) from tests.testing_utils import requires_gpu @@ -224,8 +224,6 @@ def test_quant_model_reload(format, dtype, tmp_path): shutil.rmtree(tmp_path) -# technically only tie_word_embeddings=False is supported right now -# setting to True is discouraged @pytest.mark.parametrize( "offload,torch_dtype,tie_word_embeddings,device", [ @@ -237,25 +235,23 @@ def test_quant_model_reload(format, dtype, tmp_path): # offloading (True, torch.float16, False, "cpu"), (True, torch.float32, False, "cpu"), - # (True, torch.float16, True, "cpu"), # TODO: fails - # (True, torch.float32, True, "cpu"), # TODO: fails + (True, torch.float16, True, "cpu"), + (True, torch.float32, True, "cpu"), ], ) def test_model_reload(offload, torch_dtype, tie_word_embeddings, device, tmp_path): model_path = "nm-testing/llama2.c-stories15M" save_path = tmp_path / "save_path" - model = AutoModelForCausalLM.from_pretrained( - model_path, - tie_word_embeddings=tie_word_embeddings, - torch_dtype=torch_dtype, - ) + model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype) if offload: model = dispatch_model(model, {"": device}, force_hooks=True) else: model = model.to(device) - 
patch_tied_tensors_bug(model) + if not tie_word_embeddings: + untie_word_embeddings(model) + modify_save_pretrained(model) model.save_pretrained(save_path, safe_serialization=True) @@ -294,22 +290,18 @@ def test_model_reload_gpu(offload, torch_dtype, tie_word_embeddings, device, tmp (True, torch.float32, True, "cpu"), ], ) -def test_model_shared_tensors( - offload, torch_dtype, tie_word_embeddings, device, tmp_path -): +def test_model_shared_tensors(offload, torch_dtype, tie_word_embeddings, device): # load model - model = AutoModelForCausalLM.from_pretrained( - "nm-testing/llama2.c-stories15M", - torch_dtype=torch_dtype, - tie_word_embeddings=tie_word_embeddings, - ) - patch_tied_tensors_bug(model) - + model_path = "nm-testing/llama2.c-stories15M" + model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype) if offload: model = dispatch_model(model, {"": device}, force_hooks=True) else: model = model.to(device) + if not tie_word_embeddings: + untie_word_embeddings(model) + # modify lm head with torch.no_grad(), align_module_device(model.lm_head): update_offload_parameter(model.lm_head, "weight", model.lm_head.weight + 1) @@ -332,12 +324,8 @@ def test_model_shared_tensors( (False, torch.float32, True, "cuda:0"), ], ) -def test_model_shared_tensors_gpu( - offload, torch_dtype, tie_word_embeddings, device, tmp_path -): - test_model_shared_tensors( - offload, torch_dtype, tie_word_embeddings, device, tmp_path - ) +def test_model_shared_tensors_gpu(offload, torch_dtype, tie_word_embeddings, device): + test_model_shared_tensors(offload, torch_dtype, tie_word_embeddings, device) @requires_gpu From 8542f8d1ea21f78338f7b9ca6e1df5b49c9d8232 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 15 Jul 2025 11:08:44 -0400 Subject: [PATCH 27/54] style Signed-off-by: Kyle Sayers --- src/llmcompressor/pipelines/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/pipelines/registry.py b/src/llmcompressor/pipelines/registry.py index 98fb836b0..67d510d13 100644 --- a/src/llmcompressor/pipelines/registry.py +++ b/src/llmcompressor/pipelines/registry.py @@ -61,7 +61,7 @@ def _infer_pipeline(modifiers: List[Modifier]) -> str: config = modifiers[0].resolve_quantization_config() if not config.requires_calibration_data(): return "datafree" - + # TODO: Remove hardcode if len(modifiers) == 1 and isinstance(modifiers[0], SpinQuantModifier): return "datafree" From b1e637eb88f0b9d8c5524a836d99c0baade0a54f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 15 Jul 2025 12:46:13 -0400 Subject: [PATCH 28/54] better registery logic Signed-off-by: Kyle Sayers --- .../modifiers/transform/spinquant/base.py | 12 +++++------ .../modifiers/transform/spinquant/mappings.py | 21 ++++++++++++++++--- .../transform/spinquant/norm_mappings.py | 19 +++++++++++++++-- .../compressed_tensors_utils.py | 1 - 4 files changed, 40 insertions(+), 13 deletions(-) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 5997fac19..c8376a6a0 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -15,8 +15,8 @@ from llmcompressor.modeling import fuse_norm_linears, normalize_embedding from llmcompressor.modifiers import Modifier -from .mappings import SPINQUANT_MAPPING_REGISTRY, SpinQuantMappings -from .norm_mappings import NORM_MAPPING_REGISTRY, NormMapping +from .mappings import SpinQuantMapping, infer_mapping_from_model +from 
.norm_mappings import NormMapping, infer_norm_mapping_from_model class SpinquantRotation(str, Enum): @@ -36,9 +36,7 @@ class SpinQuantModifier(Modifier, use_enum_values=True): # norm mappings separate from spinquant mappings to allow users to # override spinquant mappings with transform_config without overriding norms - # we can combine these mappings, but it requires some more validation logic - # maybe there's a reason to keep if other modifiers want norm fusing, idk - mappings: Optional[SpinQuantMappings] = Field(default=None, exclude=True) + mappings: Optional[SpinQuantMapping] = Field(default=None, exclude=True) norm_mappings: Optional[List[NormMapping]] = Field(default=None, exclude=True) # optional override for more fine-grained control @@ -53,8 +51,8 @@ def validate_rotations(cls, value): def on_initialize(self, state: State, **kwargs) -> bool: # TODO: more validation - self.mappings = SPINQUANT_MAPPING_REGISTRY[state.model.__class__.__name__] - self.norm_mappings = NORM_MAPPING_REGISTRY[state.model.__class__.__name__] + self.mappings = infer_mapping_from_model(state.model) + self.norm_mappings = infer_norm_mapping_from_model(state.model) if self.transform_config is not None: if self.mappings is not None: diff --git a/src/llmcompressor/modifiers/transform/spinquant/mappings.py b/src/llmcompressor/modifiers/transform/spinquant/mappings.py index acf692d22..7dc327b78 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/mappings.py +++ b/src/llmcompressor/modifiers/transform/spinquant/mappings.py @@ -1,9 +1,13 @@ from typing import Dict, List, Optional +from loguru import logger from pydantic import BaseModel, Field, field_validator +from transformers import PreTrainedModel +__all__ = ["SpinQuantMapping", "infer_mapping_from_model"] -class SpinQuantMappings(BaseModel): + +class SpinQuantMapping(BaseModel): embedding: str attn_q: str @@ -25,7 +29,7 @@ def cast_to_list(cls, value): return value -_default_mappings = SpinQuantMappings( +_default_mappings = SpinQuantMapping( embedding="re:.*embed_tokens$", attn_q="re:.*q_proj$", attn_k="re:.*k_proj$", @@ -37,6 +41,17 @@ def cast_to_list(cls, value): ) -SPINQUANT_MAPPING_REGISTRY: Dict[str, SpinQuantMappings] = { +SPINQUANT_MAPPING_REGISTRY: Dict[str, SpinQuantMapping] = { "LlamaForCausalLM": _default_mappings, } + + +def infer_mapping_from_model(model: PreTrainedModel) -> SpinQuantMapping: + architecture = model.__class__.__name__ + if architecture not in SPINQUANT_MAPPING_REGISTRY: + logger.info( + f"Unrecognized model architecture {architecture}. 
" + "Falling back to default mappings" + ) + + return SPINQUANT_MAPPING_REGISTRY.get(architecture, _default_mappings) diff --git a/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py b/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py index cefb987ca..0752f6986 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py +++ b/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py @@ -1,6 +1,10 @@ from typing import Dict, List +from loguru import logger from pydantic import BaseModel, field_validator +from transformers import PreTrainedModel + +__all__ = ["infer_norm_mapping_from_model"] class NormMapping(BaseModel): @@ -15,7 +19,7 @@ def cast_to_list(cls, value): return value -_default_norm_mappings = [ +_default_mappings = [ NormMapping( norm="re:.*input_layernorm$", linears=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"], @@ -31,5 +35,16 @@ def cast_to_list(cls, value): ] NORM_MAPPING_REGISTRY: Dict[str, NormMapping] = { - "LlamaForCausalLM": _default_norm_mappings, + "LlamaForCausalLM": _default_mappings, } + + +def infer_norm_mapping_from_model(model: PreTrainedModel) -> List[NormMapping]: + architecture = model.__class__.__name__ + if architecture not in NORM_MAPPING_REGISTRY: + logger.info( + f"Unrecognized model architecture {architecture}. " + "Falling back to default mappings" + ) + + return NORM_MAPPING_REGISTRY.get(architecture, _default_mappings) diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py index 0fdaa9dc6..1495f6d06 100644 --- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py @@ -14,7 +14,6 @@ register_offload_parameter, ) from loguru import logger -from safetensors.torch import storage_ptr from transformers import PreTrainedModel from llmcompressor.core import active_session From b44ac817b65dec264146c849d67566de5b38cc37 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 15 Jul 2025 13:05:05 -0400 Subject: [PATCH 29/54] remove dummy model test (add later) Signed-off-by: Kyle Sayers --- .../modifiers/transform/test_dummy_model.py | 107 ------------------ 1 file changed, 107 deletions(-) delete mode 100644 tests/llmcompressor/modifiers/transform/test_dummy_model.py diff --git a/tests/llmcompressor/modifiers/transform/test_dummy_model.py b/tests/llmcompressor/modifiers/transform/test_dummy_model.py deleted file mode 100644 index 020a61e99..000000000 --- a/tests/llmcompressor/modifiers/transform/test_dummy_model.py +++ /dev/null @@ -1,107 +0,0 @@ -import torch -from compressed_tensors.utils import update_parameter_data -from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers.models.llama.modeling_llama import LlamaRMSNorm - -from llmcompressor import oneshot -from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier -from llmcompressor.modifiers.transform import SpinQuantModifier -from llmcompressor.utils import dispatch_for_generation - -hidden_dim = intermediate_dim = 64 -up_dim = 128 -num_embeddings = 12 - - -class DummySelfAttn(torch.nn.Module): - def __init__(self, hidden_dim, intermediate_dim): - super().__init__() - self.q_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=None) - self.k_proj = torch.nn.Linear(hidden_dim, intermediate_dim, bias=None) - self.v_proj = torch.nn.Linear(hidden_dim, intermediate_dim, 
bias=None) - self.o_proj = torch.nn.Linear(hidden_dim, hidden_dim, bias=None) - self.num_heads = 1 - self.num_key_value_groups = 1 - - def forward(self, hidden_states): - q = self.q_proj(hidden_states) - k = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) - - ### EAGER ATTENTION - attn_weights = torch.matmul(q.T, k) - - attn_weights = torch.nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32 - ).to(q.dtype) - attn_output = torch.matmul(attn_weights, v.T) - attn_output = attn_output.T.contiguous() - - return self.o_proj(attn_output) - - -class DummyMLP(torch.nn.Module): - def __init__(self, hidden_dim, up_dim): - super().__init__() - self.up_proj = torch.nn.Linear(hidden_dim, up_dim, bias=None) - self.gate_proj = torch.nn.Linear(hidden_dim, up_dim, bias=None) - self.down_proj = torch.nn.Linear(up_dim, hidden_dim, bias=None) - self.act_fn = torch.nn.SiLU() - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -class DummyModel(torch.nn.Module): - def __init__(self, num_embeddings, hidden_dim, intermediate_dim, up_dim): - super().__init__() - self.embed_tokens = torch.nn.Embedding(num_embeddings, hidden_dim) - self.input_layernorm = LlamaRMSNorm(hidden_dim) - self.post_attention_layernorm = LlamaRMSNorm(hidden_dim) - self.self_attn = DummySelfAttn(hidden_dim, intermediate_dim) - self.mlp = DummyMLP(hidden_dim, up_dim) - self.lm_head = torch.nn.Linear(hidden_dim, num_embeddings, bias=None) - - def forward(self, input_ids): - x = self.embed_tokens(input_ids) - x = self.input_layernorm(x) - x = self.self_attn(x) - x = self.post_attention_layernorm(x) - x = self.mlp(x) - return self.lm_head(x) - - -def test_dummy_model(): - model = DummyModel(num_embeddings, hidden_dim, intermediate_dim, up_dim) - - # TODO Uncomment this to see norm diff > 1e-6 - # This is due to issue Kyle spotted in https://arxiv.org/pdf/2405.16406 Page 5 Footnote 2 - # Will have to fuse layernorms with subsequent layers so that input_layernorm.weight is equal to torch.ones() (this apparently makes it rotation invariant) - # https://github.com/facebookresearch/SpinQuant/blob/8f47aa3f00e8662caf1a484153920a07e5281c3a/utils/fuse_norm_utils.py#L39 - # update_parameter_data( - # model.input_layernorm, - # torch.rand(model.input_layernorm.weight.shape), - # "weight", - # ) - - input_ids = torch.IntTensor([1, 2, 3, 4, 5]) - orig_output = model(input_ids) - - recipe = [ - SpinQuantModifier(rotations=["R1", "R2"]), - ] - - # TODO: work around preprocessing? 
- oneshot( - model=model, - recipe=recipe, - pipeline="datafree", - log_dir=None, - ) - - # # Confirm generations of the quantized model look the same - transformed_output = model(input_ids) - - print(f"Norm Diff {(orig_output-transformed_output).norm()}") - print(f"Norm {orig_output.norm()}, {transformed_output.norm()}") From 7a52b710b73682119c45f61669f92b5ac6e0b189 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 15 Jul 2025 13:34:18 -0400 Subject: [PATCH 30/54] docstring Signed-off-by: Kyle Sayers --- .../modifiers/transform/spinquant/base.py | 49 ++++++++++++++++--- 1 file changed, 43 insertions(+), 6 deletions(-) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index c8376a6a0..5a1ea7844 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -8,7 +8,7 @@ TransformScheme, apply_transform_config, ) -from pydantic import Field, field_validator +from pydantic import Field, ValidationInfo, field_validator from transformers import PreTrainedModel from llmcompressor.core import Event, EventType, State @@ -27,6 +27,37 @@ class SpinquantRotation(str, Enum): class SpinQuantModifier(Modifier, use_enum_values=True): + """ + Implements the transforms according to + [SpinQuant: LLM quantization with learned rotations](https://arxiv.org/abs/2405.16406) # noqa: E501 + + Transforms (rotations) are extra layers added to a model which reduce the accuracy + loss induced by quantization. This is achived through "rotating" weights and + activations into a space with a smaller dynamic range of values, thus decreasing + the range of scales required for quantization. + + The SpinQuant authors describe four different rotations which can be applied to a + model. R1 and R2 are "offline" rotations, meaning that they can be fused into + existing weights and therefore do not induce runtime cost. R3 and R4 are "online" + rotations, meaning that they require additional computation at runtime. + + :param rotations: A list containing the names of rotations to apply to the model. + Possible rotations include R1, R2, R3, and R4 + :param transform_type: The type of transform to apply to the model. + `"hadamard"` has the least performance cost but only supports sizes which are + powers of power of two. + `"random-matrix"` has more performance cost, but supports a much larger set of + sizes. + `"random-matrix"` has the greatest performance cost, but supports any size + :param randomize: if True, create distinct transforms for each application + :param learnable: if True, attach gradients to transform weights for training + :param mappings: Specifies layers within a model to target for transforms. + A mapping will be inferred if None is provided + :param norm_mappings: Specifies layers within a model to target for norm fusing. 
+ A mapping will be inferred if None is provided + :param transform_config: Optional transform config which overrides `mappings` + """ + rotations: List[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"]) transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field( default="hadamard" @@ -43,6 +74,10 @@ class SpinQuantModifier(Modifier, use_enum_values=True): # also included in recipe serialization transform_config: Optional[TransformConfig] = Field(default=None) + @field_validator("randomize", "learnable", mode="before") + def validate_not_implemented(cls, value, info: ValidationInfo): + raise NotImplementedError(f"{info.field_name} is not supported right now") + @field_validator("rotations", mode="before") def validate_rotations(cls, value): if isinstance(value, Iterable): @@ -50,16 +85,18 @@ def validate_rotations(cls, value): return value def on_initialize(self, state: State, **kwargs) -> bool: - # TODO: more validation - self.mappings = infer_mapping_from_model(state.model) - self.norm_mappings = infer_norm_mapping_from_model(state.model) - if self.transform_config is not None: if self.mappings is not None: - raise ValueError() + raise ValueError( + "Please provide either `transform_config` or `mappings` " + "but not both" + ) return True + self.mappings = infer_mapping_from_model(state.model) + self.norm_mappings = infer_norm_mapping_from_model(state.model) + config_groups = {} if SpinquantRotation.R1 in self.rotations: config_groups["R1"] = self._create_r1_scheme() From f4d7ec6d807c629a264cc90b3fec13d1b281e242 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 15 Jul 2025 15:02:11 -0400 Subject: [PATCH 31/54] update docstring Signed-off-by: Kyle Sayers --- .../modifiers/transform/spinquant/base.py | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 5a1ea7844..2bf593635 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -55,24 +55,34 @@ class SpinQuantModifier(Modifier, use_enum_values=True): A mapping will be inferred if None is provided :param norm_mappings: Specifies layers within a model to target for norm fusing. 
A mapping will be inferred if None is provided - :param transform_config: Optional transform config which overrides `mappings` + :param transform_config: Optional transform config for overriding provided arguments """ - rotations: List[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"]) + rotations: List[SpinquantRotation] = Field( + default_factory=lambda: ["R1", "R2"], exclude=True + ) transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field( - default="hadamard" + default="hadamard", exclude=True ) - randomize: bool = Field(default=False) - learnable: bool = Field(default=False) + randomize: bool = Field(default=False, exclude=True) + learnable: bool = Field(default=False, exclude=True) # norm mappings separate from spinquant mappings to allow users to # override spinquant mappings with transform_config without overriding norms - mappings: Optional[SpinQuantMapping] = Field(default=None, exclude=True) - norm_mappings: Optional[List[NormMapping]] = Field(default=None, exclude=True) + mappings: Optional[SpinQuantMapping] = Field( + default=None, + repr=False, + exclude=True, + ) + norm_mappings: Optional[List[NormMapping]] = Field( + default=None, + repr=False, + exclude=True, + ) # optional override for more fine-grained control # also included in recipe serialization - transform_config: Optional[TransformConfig] = Field(default=None) + transform_config: Optional[TransformConfig] = Field(default=None, repr=False) @field_validator("randomize", "learnable", mode="before") def validate_not_implemented(cls, value, info: ValidationInfo): @@ -86,12 +96,6 @@ def validate_rotations(cls, value): def on_initialize(self, state: State, **kwargs) -> bool: if self.transform_config is not None: - if self.mappings is not None: - raise ValueError( - "Please provide either `transform_config` or `mappings` " - "but not both" - ) - return True self.mappings = infer_mapping_from_model(state.model) From f18d0e894d984d6ec9207f9fe71e6533669c8aa3 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 15 Jul 2025 15:07:27 -0400 Subject: [PATCH 32/54] rename example file Signed-off-by: Kyle Sayers --- examples/transform/{llama3_example.py => spinquant_example.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/transform/{llama3_example.py => spinquant_example.py} (100%) diff --git a/examples/transform/llama3_example.py b/examples/transform/spinquant_example.py similarity index 100% rename from examples/transform/llama3_example.py rename to examples/transform/spinquant_example.py From cec2914342ad337be13fccff29ca7426d713c0ec Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Jul 2025 11:13:48 -0400 Subject: [PATCH 33/54] use match_modules_set Signed-off-by: Kyle Sayers --- .../modifiers/transform/spinquant/base.py | 24 ++++--------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 2bf593635..5978b93ea 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Iterable, List, Literal, Optional -from compressed_tensors import is_match, match_named_modules +from compressed_tensors import match_modules_set, match_named_modules from compressed_tensors.transform import ( TransformArgs, TransformConfig, @@ -156,24 +156,10 @@ def _prenormalize_embeddings(self, model: PreTrainedModel): def 
_fuse_norms(self, model: PreTrainedModel): for mapping in self.norm_mappings: - targets = (mapping.norm, *mapping.linears) - matches = dict() - - for name, module in model.named_modules(): - # match until we get a full set - for target in targets: - if is_match(name, module, target): - if target in matches: - raise ValueError("Cannot match twice") - matches[target] = module - - # once we have a full set, fuse and reset - if all(target in matches for target in targets): - fuse_norm_linears( - matches[mapping.norm], - (matches[target] for target in mapping.linears), - ) - matches = dict() + for norm, *linears in match_modules_set( + model, (mapping.norm, *mapping.linears) + ): + fuse_norm_linears(norm, linears) def _create_r1_scheme(self) -> TransformScheme: return TransformScheme( From 0c5c514313d887caf715aa9f14bdb35f50e3bad6 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 17 Jul 2025 14:12:55 -0400 Subject: [PATCH 34/54] unit test fixes Signed-off-by: Brian Dellabetta --- src/llmcompressor/modeling/fuse.py | 3 ++- tests/llmcompressor/modeling/test_fuse.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/modeling/fuse.py b/src/llmcompressor/modeling/fuse.py index 40dc31e6a..e59be596c 100644 --- a/src/llmcompressor/modeling/fuse.py +++ b/src/llmcompressor/modeling/fuse.py @@ -32,6 +32,7 @@ def normalize_embedding(embedding: torch.nn.Module): else: raise ValueError(f"Cannot normalize embedding of type {type(embedding)}") + def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]): """ Fuse a norm layer into subsequent linear layers. This useful for ensuring transform @@ -42,7 +43,7 @@ def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]) :param norm: norm layer whose weight will be fused into subsequent linears :param linears: linear layers which directly follow the norm layer """ - if isinstance(norm, (torch.nn.RMSNorm, LlamaRMSNorm)): + if isinstance(norm, (torch.nn.RMSNorm, LlamaRMSNorm, torch.nn.LayerNorm)): for linear in linears: # NOTE: spinquant does this op in float64 exec_device = get_execution_device(norm) diff --git a/tests/llmcompressor/modeling/test_fuse.py b/tests/llmcompressor/modeling/test_fuse.py index 005d89f99..f85cd68dc 100644 --- a/tests/llmcompressor/modeling/test_fuse.py +++ b/tests/llmcompressor/modeling/test_fuse.py @@ -1,13 +1,13 @@ import pytest import torch -from llmcompressor.modeling.fuse import center_embeddings, fuse_norm_linears +from llmcompressor.modeling.fuse import normalize_embedding, fuse_norm_linears @pytest.mark.unit -def test_center_embeddings(): +def test_normalize_embedding(): embedding = torch.nn.Embedding(10, 10) - center_embeddings(embedding) + normalize_embedding(embedding) assert torch.allclose( embedding.weight.mean(dim=1), torch.zeros(embedding.num_embeddings), atol=1e-5 From f2ef7cfd5734434b285c44ac43c8f108ba9afae1 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 17 Jul 2025 14:16:26 -0400 Subject: [PATCH 35/54] style fixes Signed-off-by: Brian Dellabetta --- tests/llmcompressor/modeling/test_fuse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/llmcompressor/modeling/test_fuse.py b/tests/llmcompressor/modeling/test_fuse.py index f85cd68dc..5798f692c 100644 --- a/tests/llmcompressor/modeling/test_fuse.py +++ b/tests/llmcompressor/modeling/test_fuse.py @@ -1,7 +1,7 @@ import pytest import torch -from llmcompressor.modeling.fuse import normalize_embedding, fuse_norm_linears +from llmcompressor.modeling.fuse 
import fuse_norm_linears, normalize_embedding @pytest.mark.unit From d0e5bc5816f5ce405a180bd4635b789f0e356ede Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 24 Jul 2025 21:27:35 +0000 Subject: [PATCH 36/54] remove hardcoded pipeline logic Signed-off-by: Brian Dellabetta --- src/llmcompressor/pipelines/registry.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/llmcompressor/pipelines/registry.py b/src/llmcompressor/pipelines/registry.py index 67d510d13..2c1a54cf5 100644 --- a/src/llmcompressor/pipelines/registry.py +++ b/src/llmcompressor/pipelines/registry.py @@ -8,7 +8,6 @@ from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.quantization import QuantizationModifier -from llmcompressor.modifiers.transform import SpinQuantModifier if TYPE_CHECKING: from llmcompressor.args.dataset_arguments import DatasetArguments @@ -62,8 +61,4 @@ def _infer_pipeline(modifiers: List[Modifier]) -> str: if not config.requires_calibration_data(): return "datafree" - # TODO: Remove hardcode - if len(modifiers) == 1 and isinstance(modifiers[0], SpinQuantModifier): - return "datafree" - return "sequential" From 31ac8e95e6426ee4056254d551310ee2f201d04f Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 24 Jul 2025 22:15:49 +0000 Subject: [PATCH 37/54] docstrings Signed-off-by: Brian Dellabetta --- .../modifiers/transform/spinquant/base.py | 14 ++++++++++++++ .../modifiers/transform/spinquant/mappings.py | 19 +++++++++++++++++++ .../transform/spinquant/norm_mappings.py | 10 ++++++++++ 3 files changed, 43 insertions(+) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 5978b93ea..23bf604b0 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -41,6 +41,20 @@ class SpinQuantModifier(Modifier, use_enum_values=True): existing weights and therefore do not induce runtime cost. R3 and R4 are "online" rotations, meaning that they require additional computation at runtime. + Lifecycle: + - on_initialize + - infer SpinQuantMappings & NormMappings + - as needed, create transform schemes for R1, R2, R3, & R4 + - on_start + - normalize embeddings + - fuse norm layers into subsequent Linear layers + - apply TransformConfig + - fuse transforms into weights for mergeable transforms + - add hooks for online transforms + - on sequential epoch end + - on_end + - on_finalize + :param rotations: A list containing the names of rotations to apply to the model. Possible rotations include R1, R2, R3, and R4 :param transform_type: The type of transform to apply to the model. diff --git a/src/llmcompressor/modifiers/transform/spinquant/mappings.py b/src/llmcompressor/modifiers/transform/spinquant/mappings.py index 7dc327b78..514d1f109 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/mappings.py +++ b/src/llmcompressor/modifiers/transform/spinquant/mappings.py @@ -8,6 +8,25 @@ class SpinQuantMapping(BaseModel): + """ + SpinQuant needs to know the entire architecture of the model, + as R1, R2, R3, and R4 rotations need to be applied to specific + layers (https://arxiv.org/pdf/2405.16406 Fig. 1). 
+ + :param embedding: name or regex of embedding layer + :param attn_q: name or regex of q_proj layer in attention block + :param attn_k: name or regex of k_proj layer in attention block + :param attn_v: name or regex of v_proj layer in attention block + :param attn_o: name or regex of o_proj layer in attention block + :param attn_head_dim: head_dim of the attention module, needed + because R2 needs to be applied "head-wisely" to v_proj and + o_proj + :param mlp_in: list of names or regexes for the mlp blocks that + receive the input to the MLP block, usually up_proj and gate_proj + :param mlp_out: list of names or regexes for the mlp blocks that + consitute the output of the MLP block, usually down_proj + """ + embedding: str attn_q: str diff --git a/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py b/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py index 0752f6986..896b0db3d 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py +++ b/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py @@ -8,6 +8,16 @@ class NormMapping(BaseModel): + """ + SpinQuant needs to know where every norm layer exists in the model, + as well as all the subsequent Linear layers the norm passes into. + This is because the norm layer weights need to normalized before + transforms can be fused into Linear layers. + + :param norm: name or regex that matches norm layer in model + :param linears: list of names or regexes of Linear layers that + receive input from norm. + """ norm: str linears: List[str] From a4abb3d31bcd2cf067728b2f12f2ba06d53503f6 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 24 Jul 2025 22:26:51 +0000 Subject: [PATCH 38/54] stylefixes Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/transform/spinquant/base.py | 4 ++-- .../modifiers/transform/spinquant/norm_mappings.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 23bf604b0..b9a18e961 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -28,8 +28,8 @@ class SpinquantRotation(str, Enum): class SpinQuantModifier(Modifier, use_enum_values=True): """ - Implements the transforms according to - [SpinQuant: LLM quantization with learned rotations](https://arxiv.org/abs/2405.16406) # noqa: E501 + Implements the transforms according to "SpinQuant: LLM quantization + with learned rotations" (https://arxiv.org/abs/2405.16406) Transforms (rotations) are extra layers added to a model which reduce the accuracy loss induced by quantization. This is achived through "rotating" weights and diff --git a/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py b/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py index 896b0db3d..e60ac0d1a 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py +++ b/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py @@ -18,6 +18,7 @@ class NormMapping(BaseModel): :param linears: list of names or regexes of Linear layers that receive input from norm. 
""" + norm: str linears: List[str] From 490b9875584849efaf6f0ce45d122e86137e8a1a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 15 Jul 2025 14:56:53 -0400 Subject: [PATCH 39/54] implement quip Signed-off-by: Kyle Sayers --- .../modifiers/transform/__init__.py | 1 + .../modifiers/transform/quip/__init__.py | 3 + .../modifiers/transform/quip/base.py | 131 ++++++++++++++++++ 3 files changed, 135 insertions(+) create mode 100644 src/llmcompressor/modifiers/transform/quip/__init__.py create mode 100644 src/llmcompressor/modifiers/transform/quip/base.py diff --git a/src/llmcompressor/modifiers/transform/__init__.py b/src/llmcompressor/modifiers/transform/__init__.py index 9956d0340..4e71c3f03 100644 --- a/src/llmcompressor/modifiers/transform/__init__.py +++ b/src/llmcompressor/modifiers/transform/__init__.py @@ -1,3 +1,4 @@ # flake8: noqa from .spinquant import SpinQuantModifier +from .quip import QuIPModifier \ No newline at end of file diff --git a/src/llmcompressor/modifiers/transform/quip/__init__.py b/src/llmcompressor/modifiers/transform/quip/__init__.py new file mode 100644 index 000000000..8bdc93d14 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/quip/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .base import * diff --git a/src/llmcompressor/modifiers/transform/quip/base.py b/src/llmcompressor/modifiers/transform/quip/base.py new file mode 100644 index 000000000..a718c0923 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/quip/base.py @@ -0,0 +1,131 @@ +from typing import Iterable, List, Literal, Optional, Union + +from compressed_tensors.transform import ( + TransformArgs, + TransformConfig, + TransformScheme, + apply_transform_config, +) +from pydantic import Field, ValidationInfo, field_validator + +from llmcompressor.core import Event, EventType, State +from llmcompressor.modifiers import Modifier + +__all__ = ["QuIPModifier"] + + +class QuIPModifier(Modifier): + """ + Implements the transforms according to + [QuIP#: Even Better LLM Quantization with Hadamard Incoherence and Lattice Codebooks](https://arxiv.org/pdf/2402.04396) # noqa: E501 + [QuIP: 2-Bit Quantization of Large Language Models With Guarantees](https://arxiv.org/abs/2307.13304) # noqa: E501 + + Transforms (rotations) are extra layers added to a model which reduce the accuracy + loss induced by quantization. This is achived through "rotating" weights and + activations into a space with a smaller dynamic range of values, thus decreasing + the range of scales required for quantization. + + QuIP and QuIP# apply transforms to every linear layer, two of which are fused into + the model weights and two of which remain as online rotations computed at runtime. + + :param transform_type: The type of transform to apply to the model. + `"hadamard"` has the least performance cost but only supports sizes which are + powers of power of two. + `"random-matrix"` has more performance cost, but supports a much larger set of + sizes. 
+ `"random-matrix"` has the greatest performance cost, but supports any size + :param randomize: If true, create distinct transforms for each application + :param learnable: If true, attach gradients to transform weights for training + :param ignore: Modules to ignore when attaching transforms + :param transform_config: Optional transform config for overriding provided arguments + """ + + transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field( + default="hadamard", exclude=True + ) + randomize: bool = Field(default=False, exclude=True) + learnable: bool = Field(default=False, exclude=True) + ignore: Union[str, List[str]] = Field(default="lm_head", exclude=True) + + # optional override for more fine-grained control + # also included in recipe serialization + transform_config: Optional[TransformConfig] = Field(default=None, repr=False) + + @field_validator("randomize", "learnable", mode="before") + def validate_not_implemented(cls, value, info: ValidationInfo): + raise NotImplementedError(f"{info.field_name} is not supported right now") + + def on_initialize(self, state: State, **kwargs) -> bool: + if self.transform_config is not None: + return True + + self.transform_config = self._create_config() + return True + + def on_start(self, state: State, event: Event, **kwargs): + self.started_ = True + + apply_transform_config(state.model, self.transform_config) + + def on_event(self, state: State, event: Event, **kwargs): + if event.type_ == EventType.CALIBRATION_EPOCH_START: + if not self.started_: + self.on_start(state, None) + + elif event.type_ == EventType.SEQUENTIAL_EPOCH_END: + pass + + elif event.type_ == EventType.CALIBRATION_EPOCH_END: + if not self.ended_: + self.on_end(state, None) + + def on_end(self, state: State, event: Event, **kwargs): + self.ended_ = True + + def on_finalize(self, state: State, **kwargs) -> bool: + if not self.ended_: + self.on_end(state, None) + + return True + + def _create_config(self) -> TransformConfig: + return TransformConfig( + config_groups={ + "v": TransformScheme( + type=self.transform_type, + apply=[ + TransformArgs( + targets=["Linear"], + location="input", # non-mergable + ignore=self.ignore, + ), + TransformArgs( + targets=["Linear"], + location="weight_input", + inverse=True, + ignore=self.ignore, + ), + ], + randomize=self.randomize, + requires_grad=self.learnable, + ), + "u": TransformScheme( + type=self.transform_type, + apply=[ + TransformArgs( + targets=["Linear"], + location="weight_output", + ignore=self.ignore, + ), + TransformArgs( + targets=["Linear"], + location="output", # non-mergable + inverse=True, + ignore=self.ignore, + ), + ], + randomize=self.randomize, + requires_grad=self.learnable, + ), + } + ) From ac7dbcd6c4eba694572f9944cabf9b038682c74d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 15 Jul 2025 15:03:02 -0400 Subject: [PATCH 40/54] add example, cleanup Signed-off-by: Kyle Sayers --- examples/transform/quip_example.py | 87 +++++++++++++++++++ .../modifiers/transform/__init__.py | 2 +- .../modifiers/transform/quip/base.py | 4 +- 3 files changed, 90 insertions(+), 3 deletions(-) create mode 100644 examples/transform/quip_example.py diff --git a/examples/transform/quip_example.py b/examples/transform/quip_example.py new file mode 100644 index 000000000..e5f7faea0 --- /dev/null +++ b/examples/transform/quip_example.py @@ -0,0 +1,87 @@ +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llmcompressor import oneshot +from 
llmcompressor.modifiers.quantization import QuantizationModifier +from llmcompressor.modifiers.transform import QuIPModifier +from llmcompressor.utils import dispatch_for_generation + +# Select model and load it. +MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" + +model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, + torch_dtype="auto", +) +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + +# Select calibration dataset. +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" + +# Select number of samples. 512 samples is a good place to start. +# Increasing the number of samples can improve accuracy. +NUM_CALIBRATION_SAMPLES = 512 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") +ds = ds.shuffle(seed=42) + + +def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# Configure the quantization algorithm to run. +# * apply spinquant transforms to model in order to make quantization easier +# * quantize the weights to 4 bit with GPTQ with a group size 128 +recipe = [ + QuIPModifier(transform_type="random-hadamard"), + QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), +] + +# Apply algorithms. +oneshot( + model=model, + recipe=recipe, + dataset=ds, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + pipeline="basic", +) + +# Confirm generations of the quantized model look sane. +print("\n\n") +print("========== SAMPLE GENERATION ==============") +dispatch_for_generation(model) +input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=100) +print(tokenizer.decode(output[0])) +print("==========================================\n\n") + +# Save to disk compressed. +SAVE_DIR = MODEL_ID.split("/")[1] + "-transformed-w4a16" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modifiers/transform/__init__.py b/src/llmcompressor/modifiers/transform/__init__.py index 4e71c3f03..eaa714183 100644 --- a/src/llmcompressor/modifiers/transform/__init__.py +++ b/src/llmcompressor/modifiers/transform/__init__.py @@ -1,4 +1,4 @@ # flake8: noqa +from .quip import QuIPModifier from .spinquant import SpinQuantModifier -from .quip import QuIPModifier \ No newline at end of file diff --git a/src/llmcompressor/modifiers/transform/quip/base.py b/src/llmcompressor/modifiers/transform/quip/base.py index a718c0923..8c86a1471 100644 --- a/src/llmcompressor/modifiers/transform/quip/base.py +++ b/src/llmcompressor/modifiers/transform/quip/base.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Literal, Optional, Union +from typing import List, Literal, Optional, Union from compressed_tensors.transform import ( TransformArgs, @@ -24,7 +24,7 @@ class QuIPModifier(Modifier): loss induced by quantization. This is achived through "rotating" weights and activations into a space with a smaller dynamic range of values, thus decreasing the range of scales required for quantization. 
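A minimal numerical sketch of the effect described above: an orthonormal Hadamard rotation spreads a single outlier across coordinates, shrinking the dynamic range the quantizer has to cover (values are illustrative only).

import torch

# orthonormal 4x4 Hadamard rotation (rows have unit norm)
H = torch.tensor([
    [1.0,  1.0,  1.0,  1.0],
    [1.0, -1.0,  1.0, -1.0],
    [1.0,  1.0, -1.0, -1.0],
    [1.0, -1.0, -1.0,  1.0],
]) / 2.0

x = torch.tensor([8.0, 0.1, -0.2, 0.05])  # one large outlier
print(x.abs().max() / x.abs().mean())              # ~3.8: wide dynamic range
print((H @ x).abs().max() / (H @ x).abs().mean())  # ~1.03: outlier spread across coordinates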
- + QuIP and QuIP# apply transforms to every linear layer, two of which are fused into the model weights and two of which remain as online rotations computed at runtime. From a5d3ddc6a87d23755e3d89ffd58c8d544f2a50b5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 1 Aug 2025 19:37:16 -0400 Subject: [PATCH 41/54] update quip example Signed-off-by: Kyle Sayers --- examples/transform/quip_example.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/transform/quip_example.py b/examples/transform/quip_example.py index e5f7faea0..5be0d20e1 100644 --- a/examples/transform/quip_example.py +++ b/examples/transform/quip_example.py @@ -69,7 +69,7 @@ def tokenize(sample): dataset=ds, max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, - pipeline="basic", + pipeline="datafree", ) # Confirm generations of the quantized model look sane. @@ -82,6 +82,6 @@ def tokenize(sample): print("==========================================\n\n") # Save to disk compressed. -SAVE_DIR = MODEL_ID.split("/")[1] + "-transformed-w4a16" +SAVE_DIR = MODEL_ID.split("/")[1] + "-quip-w4a16" model.save_pretrained(SAVE_DIR, save_compressed=True) tokenizer.save_pretrained(SAVE_DIR) From a21648db8ada5b11738de9b8163ce3369fff5462 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 4 Aug 2025 12:37:06 -0400 Subject: [PATCH 42/54] prepare for merge without spinquant Signed-off-by: Kyle Sayers --- examples/transform/quip_example.py | 2 +- examples/transform/spinquant_example.py | 86 ------- .../modifiers/transform/__init__.py | 1 - .../modifiers/transform/quip/base.py | 10 +- .../modifiers/transform/spinquant/__init__.py | 3 - .../modifiers/transform/spinquant/base.py | 235 ------------------ .../modifiers/transform/spinquant/mappings.py | 76 ------ .../transform/spinquant/norm_mappings.py | 61 ----- .../modifiers/transform/test_correctness.py | 25 +- 9 files changed, 21 insertions(+), 478 deletions(-) delete mode 100644 examples/transform/spinquant_example.py delete mode 100644 src/llmcompressor/modifiers/transform/spinquant/__init__.py delete mode 100644 src/llmcompressor/modifiers/transform/spinquant/base.py delete mode 100644 src/llmcompressor/modifiers/transform/spinquant/mappings.py delete mode 100644 src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py diff --git a/examples/transform/quip_example.py b/examples/transform/quip_example.py index 5be0d20e1..26f76f4ec 100644 --- a/examples/transform/quip_example.py +++ b/examples/transform/quip_example.py @@ -7,7 +7,7 @@ from llmcompressor.utils import dispatch_for_generation # Select model and load it. -MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" +MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct" model = AutoModelForCausalLM.from_pretrained( MODEL_ID, diff --git a/examples/transform/spinquant_example.py b/examples/transform/spinquant_example.py deleted file mode 100644 index 876db7138..000000000 --- a/examples/transform/spinquant_example.py +++ /dev/null @@ -1,86 +0,0 @@ -from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer - -from llmcompressor import oneshot -from llmcompressor.modifiers.quantization import QuantizationModifier -from llmcompressor.modifiers.transform import SpinQuantModifier -from llmcompressor.utils import dispatch_for_generation - -# Select model and load it. 
-MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" - -model = AutoModelForCausalLM.from_pretrained( - MODEL_ID, - torch_dtype="auto", -) -tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) - -# Select calibration dataset. -DATASET_ID = "HuggingFaceH4/ultrachat_200k" -DATASET_SPLIT = "train_sft" - -# Select number of samples. 512 samples is a good place to start. -# Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 512 -MAX_SEQUENCE_LENGTH = 2048 - -# Load dataset and preprocess. -ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") -ds = ds.shuffle(seed=42) - - -def preprocess(example): - return { - "text": tokenizer.apply_chat_template( - example["messages"], - tokenize=False, - ) - } - - -ds = ds.map(preprocess) - - -# Tokenize inputs. -def tokenize(sample): - return tokenizer( - sample["text"], - padding=False, - max_length=MAX_SEQUENCE_LENGTH, - truncation=True, - add_special_tokens=False, - ) - - -ds = ds.map(tokenize, remove_columns=ds.column_names) - -# Configure the quantization algorithm to run. -# * apply spinquant transforms to model in order to make quantization easier -# * quantize the weights to 4 bit with GPTQ with a group size 128 -recipe = [ - SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard"), - QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), -] - -# Apply algorithms. -oneshot( - model=model, - recipe=recipe, - dataset=ds, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, -) - -# Confirm generations of the quantized model look sane. -print("\n\n") -print("========== SAMPLE GENERATION ==============") -dispatch_for_generation(model) -input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") -output = model.generate(input_ids, max_new_tokens=100) -print(tokenizer.decode(output[0])) -print("==========================================\n\n") - -# Save to disk compressed. -SAVE_DIR = MODEL_ID.split("/")[1] + "-transformed-w4a16" -model.save_pretrained(SAVE_DIR, save_compressed=True) -tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modifiers/transform/__init__.py b/src/llmcompressor/modifiers/transform/__init__.py index eaa714183..e37e388f4 100644 --- a/src/llmcompressor/modifiers/transform/__init__.py +++ b/src/llmcompressor/modifiers/transform/__init__.py @@ -1,4 +1,3 @@ # flake8: noqa from .quip import QuIPModifier -from .spinquant import SpinQuantModifier diff --git a/src/llmcompressor/modifiers/transform/quip/base.py b/src/llmcompressor/modifiers/transform/quip/base.py index 8c86a1471..d34245055 100644 --- a/src/llmcompressor/modifiers/transform/quip/base.py +++ b/src/llmcompressor/modifiers/transform/quip/base.py @@ -21,7 +21,7 @@ class QuIPModifier(Modifier): [QuIP: 2-Bit Quantization of Large Language Models With Guarantees](https://arxiv.org/abs/2307.13304) # noqa: E501 Transforms (rotations) are extra layers added to a model which reduce the accuracy - loss induced by quantization. This is achived through "rotating" weights and + loss induced by quantization. This is achieved through "rotating" weights and activations into a space with a smaller dynamic range of values, thus decreasing the range of scales required for quantization. @@ -31,7 +31,7 @@ class QuIPModifier(Modifier): :param transform_type: The type of transform to apply to the model. `"hadamard"` has the least performance cost but only supports sizes which are powers of power of two. 
- `"random-matrix"` has more performance cost, but supports a much larger set of + `"random-hadamard"` has more performance cost, but supports a much larger set of sizes. `"random-matrix"` has the greatest performance cost, but supports any size :param randomize: If true, create distinct transforms for each application @@ -53,7 +53,9 @@ class QuIPModifier(Modifier): @field_validator("randomize", "learnable", mode="before") def validate_not_implemented(cls, value, info: ValidationInfo): - raise NotImplementedError(f"{info.field_name} is not supported right now") + if value: + raise NotImplementedError(f"{info.field_name} is not supported right now") + return value def on_initialize(self, state: State, **kwargs) -> bool: if self.transform_config is not None: @@ -102,6 +104,7 @@ def _create_config(self) -> TransformConfig: TransformArgs( targets=["Linear"], location="weight_input", + # location="input", inverse=True, ignore=self.ignore, ), @@ -115,6 +118,7 @@ def _create_config(self) -> TransformConfig: TransformArgs( targets=["Linear"], location="weight_output", + # location="output", ignore=self.ignore, ), TransformArgs( diff --git a/src/llmcompressor/modifiers/transform/spinquant/__init__.py b/src/llmcompressor/modifiers/transform/spinquant/__init__.py deleted file mode 100644 index 8bdc93d14..000000000 --- a/src/llmcompressor/modifiers/transform/spinquant/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# flake8: noqa - -from .base import * diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py deleted file mode 100644 index b9a18e961..000000000 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ /dev/null @@ -1,235 +0,0 @@ -from enum import Enum -from typing import Iterable, List, Literal, Optional - -from compressed_tensors import match_modules_set, match_named_modules -from compressed_tensors.transform import ( - TransformArgs, - TransformConfig, - TransformScheme, - apply_transform_config, -) -from pydantic import Field, ValidationInfo, field_validator -from transformers import PreTrainedModel - -from llmcompressor.core import Event, EventType, State -from llmcompressor.modeling import fuse_norm_linears, normalize_embedding -from llmcompressor.modifiers import Modifier - -from .mappings import SpinQuantMapping, infer_mapping_from_model -from .norm_mappings import NormMapping, infer_norm_mapping_from_model - - -class SpinquantRotation(str, Enum): - R1 = "R1" - R2 = "R2" - R3 = "R3" - R4 = "R4" - - -class SpinQuantModifier(Modifier, use_enum_values=True): - """ - Implements the transforms according to "SpinQuant: LLM quantization - with learned rotations" (https://arxiv.org/abs/2405.16406) - - Transforms (rotations) are extra layers added to a model which reduce the accuracy - loss induced by quantization. This is achived through "rotating" weights and - activations into a space with a smaller dynamic range of values, thus decreasing - the range of scales required for quantization. - - The SpinQuant authors describe four different rotations which can be applied to a - model. R1 and R2 are "offline" rotations, meaning that they can be fused into - existing weights and therefore do not induce runtime cost. R3 and R4 are "online" - rotations, meaning that they require additional computation at runtime. 
- - Lifecycle: - - on_initialize - - infer SpinQuantMappings & NormMappings - - as needed, create transform schemes for R1, R2, R3, & R4 - - on_start - - normalize embeddings - - fuse norm layers into subsequent Linear layers - - apply TransformConfig - - fuse transforms into weights for mergeable transforms - - add hooks for online transforms - - on sequential epoch end - - on_end - - on_finalize - - :param rotations: A list containing the names of rotations to apply to the model. - Possible rotations include R1, R2, R3, and R4 - :param transform_type: The type of transform to apply to the model. - `"hadamard"` has the least performance cost but only supports sizes which are - powers of power of two. - `"random-matrix"` has more performance cost, but supports a much larger set of - sizes. - `"random-matrix"` has the greatest performance cost, but supports any size - :param randomize: if True, create distinct transforms for each application - :param learnable: if True, attach gradients to transform weights for training - :param mappings: Specifies layers within a model to target for transforms. - A mapping will be inferred if None is provided - :param norm_mappings: Specifies layers within a model to target for norm fusing. - A mapping will be inferred if None is provided - :param transform_config: Optional transform config for overriding provided arguments - """ - - rotations: List[SpinquantRotation] = Field( - default_factory=lambda: ["R1", "R2"], exclude=True - ) - transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field( - default="hadamard", exclude=True - ) - randomize: bool = Field(default=False, exclude=True) - learnable: bool = Field(default=False, exclude=True) - - # norm mappings separate from spinquant mappings to allow users to - # override spinquant mappings with transform_config without overriding norms - mappings: Optional[SpinQuantMapping] = Field( - default=None, - repr=False, - exclude=True, - ) - norm_mappings: Optional[List[NormMapping]] = Field( - default=None, - repr=False, - exclude=True, - ) - - # optional override for more fine-grained control - # also included in recipe serialization - transform_config: Optional[TransformConfig] = Field(default=None, repr=False) - - @field_validator("randomize", "learnable", mode="before") - def validate_not_implemented(cls, value, info: ValidationInfo): - raise NotImplementedError(f"{info.field_name} is not supported right now") - - @field_validator("rotations", mode="before") - def validate_rotations(cls, value): - if isinstance(value, Iterable): - return tuple(v.upper() for v in value) - return value - - def on_initialize(self, state: State, **kwargs) -> bool: - if self.transform_config is not None: - return True - - self.mappings = infer_mapping_from_model(state.model) - self.norm_mappings = infer_norm_mapping_from_model(state.model) - - config_groups = {} - if SpinquantRotation.R1 in self.rotations: - config_groups["R1"] = self._create_r1_scheme() - - if SpinquantRotation.R2 in self.rotations: - config_groups["R2"] = self._create_r2_scheme(state.model) - - if SpinquantRotation.R3 in self.rotations: - config_groups["R3"] = self._create_r3_scheme() - - if SpinquantRotation.R4 in self.rotations: - config_groups["R4"] = self._create_r4_scheme() - - self.transform_config = TransformConfig(config_groups=config_groups) - - return True - - def on_start(self, state: State, event: Event, **kwargs): - self.started_ = True - - # needs to happen after the model has been hooked to execute on the GPU - # otherwise 
we're applying weight transforms on CPU - self._prenormalize_embeddings(state.model) - self._fuse_norms(state.model) - apply_transform_config(state.model, self.transform_config) - - def on_event(self, state: State, event: Event, **kwargs): - if event.type_ == EventType.CALIBRATION_EPOCH_START: - if not self.started_: - self.on_start(state, None) - - elif event.type_ == EventType.SEQUENTIAL_EPOCH_END: - pass - - elif event.type_ == EventType.CALIBRATION_EPOCH_END: - if not self.ended_: - self.on_end(state, None) - - def on_end(self, state: State, event: Event, **kwargs): - self.ended_ = True - - def on_finalize(self, state: State, **kwargs) -> bool: - if not self.ended_: - self.on_end(state, None) - - return True - - def _prenormalize_embeddings(self, model: PreTrainedModel): - for _, embedding in match_named_modules( - model, [self.mappings.embedding], warn_on_fail=True - ): - normalize_embedding(embedding) - - def _fuse_norms(self, model: PreTrainedModel): - for mapping in self.norm_mappings: - for norm, *linears in match_modules_set( - model, (mapping.norm, *mapping.linears) - ): - fuse_norm_linears(norm, linears) - - def _create_r1_scheme(self) -> TransformScheme: - return TransformScheme( - type=self.transform_type, - randomize=self.randomize, - requires_grad=self.learnable, - apply=[ - TransformArgs( - targets=[ - self.mappings.embedding, - self.mappings.attn_o, - *self.mappings.mlp_out, - ], - location="weight_output", - ), - TransformArgs( - targets=[ - self.mappings.attn_q, - self.mappings.attn_k, - self.mappings.attn_v, - *self.mappings.mlp_in, - self.mappings.lm_head, - ], - location="weight_input", - inverse=True, - ), - ], - ) - - def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme: - config = model.config - - if hasattr(config, "head_dim"): - head_dim = config.head_dim - elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"): - head_dim = config.hidden_size // config.num_attention_heads - else: - raise NotImplementedError() - - return TransformScheme( - type=self.transform_type, - randomize=self.randomize, - requires_grad=self.learnable, - head_dim=head_dim, - apply=[ - TransformArgs(targets=[self.mappings.attn_v], location="weight_output"), - TransformArgs( - targets=[self.mappings.attn_o], - location="weight_input", - inverse=True, - ), - ], - ) - - def _create_r3_scheme(self) -> TransformScheme: - raise NotImplementedError() - - def _create_r4_scheme(self) -> TransformScheme: - raise NotImplementedError() diff --git a/src/llmcompressor/modifiers/transform/spinquant/mappings.py b/src/llmcompressor/modifiers/transform/spinquant/mappings.py deleted file mode 100644 index 514d1f109..000000000 --- a/src/llmcompressor/modifiers/transform/spinquant/mappings.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import Dict, List, Optional - -from loguru import logger -from pydantic import BaseModel, Field, field_validator -from transformers import PreTrainedModel - -__all__ = ["SpinQuantMapping", "infer_mapping_from_model"] - - -class SpinQuantMapping(BaseModel): - """ - SpinQuant needs to know the entire architecture of the model, - as R1, R2, R3, and R4 rotations need to be applied to specific - layers (https://arxiv.org/pdf/2405.16406 Fig. 1). 
- - :param embedding: name or regex of embedding layer - :param attn_q: name or regex of q_proj layer in attention block - :param attn_k: name or regex of k_proj layer in attention block - :param attn_v: name or regex of v_proj layer in attention block - :param attn_o: name or regex of o_proj layer in attention block - :param attn_head_dim: head_dim of the attention module, needed - because R2 needs to be applied "head-wisely" to v_proj and - o_proj - :param mlp_in: list of names or regexes for the mlp blocks that - receive the input to the MLP block, usually up_proj and gate_proj - :param mlp_out: list of names or regexes for the mlp blocks that - consitute the output of the MLP block, usually down_proj - """ - - embedding: str - - attn_q: str - attn_k: str - attn_v: str - attn_o: str - attn_head_dim: Optional[int] = Field(default=None) - - mlp_in: List[str] # up_proj, gate_proj - mlp_out: List[str] # down_proj - - lm_head: str - - @field_validator("mlp_in", "mlp_out", mode="before") - def cast_to_list(cls, value): - if isinstance(value, str): - return [value] - - return value - - -_default_mappings = SpinQuantMapping( - embedding="re:.*embed_tokens$", - attn_q="re:.*q_proj$", - attn_k="re:.*k_proj$", - attn_v="re:.*v_proj$", - attn_o="re:.*o_proj$", - mlp_in=["re:.*up_proj$", "re:.*gate_proj$"], - mlp_out="re:.*down_proj$", - lm_head="lm_head", -) - - -SPINQUANT_MAPPING_REGISTRY: Dict[str, SpinQuantMapping] = { - "LlamaForCausalLM": _default_mappings, -} - - -def infer_mapping_from_model(model: PreTrainedModel) -> SpinQuantMapping: - architecture = model.__class__.__name__ - if architecture not in SPINQUANT_MAPPING_REGISTRY: - logger.info( - f"Unrecognized model architecture {architecture}. " - "Falling back to default mappings" - ) - - return SPINQUANT_MAPPING_REGISTRY.get(architecture, _default_mappings) diff --git a/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py b/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py deleted file mode 100644 index e60ac0d1a..000000000 --- a/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py +++ /dev/null @@ -1,61 +0,0 @@ -from typing import Dict, List - -from loguru import logger -from pydantic import BaseModel, field_validator -from transformers import PreTrainedModel - -__all__ = ["infer_norm_mapping_from_model"] - - -class NormMapping(BaseModel): - """ - SpinQuant needs to know where every norm layer exists in the model, - as well as all the subsequent Linear layers the norm passes into. - This is because the norm layer weights need to normalized before - transforms can be fused into Linear layers. - - :param norm: name or regex that matches norm layer in model - :param linears: list of names or regexes of Linear layers that - receive input from norm. 
- """ - - norm: str - linears: List[str] - - @field_validator("linears", mode="before") - def cast_to_list(cls, value): - if isinstance(value, str): - return [value] - - return value - - -_default_mappings = [ - NormMapping( - norm="re:.*input_layernorm$", - linears=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"], - ), - NormMapping( - norm="re:.*post_attention_layernorm$", - linears=["re:.*up_proj$", "re:.*gate_proj$"], - ), - NormMapping( - norm="model.norm", - linears=["lm_head"], - ), -] - -NORM_MAPPING_REGISTRY: Dict[str, NormMapping] = { - "LlamaForCausalLM": _default_mappings, -} - - -def infer_norm_mapping_from_model(model: PreTrainedModel) -> List[NormMapping]: - architecture = model.__class__.__name__ - if architecture not in NORM_MAPPING_REGISTRY: - logger.info( - f"Unrecognized model architecture {architecture}. " - "Falling back to default mappings" - ) - - return NORM_MAPPING_REGISTRY.get(architecture, _default_mappings) diff --git a/tests/llmcompressor/modifiers/transform/test_correctness.py b/tests/llmcompressor/modifiers/transform/test_correctness.py index 660bab0ef..6d91f630c 100644 --- a/tests/llmcompressor/modifiers/transform/test_correctness.py +++ b/tests/llmcompressor/modifiers/transform/test_correctness.py @@ -1,34 +1,35 @@ import pytest import torch -from compressed_tensors.transform import apply_transform_config from transformers import AutoModelForCausalLM -from llmcompressor.modifiers.transform.template.quip import QUIP +from llmcompressor.core import State +from llmcompressor.modifiers.transform import QuIPModifier +from tests.testing_utils import requires_gpu +@requires_gpu @pytest.mark.parametrize( - "dtype,exp_max,exp_mse", + "dtype,exp_mse", [ - ( - torch.bfloat16, - 1.1, - 0.012, - ), # constructing and running transforms in float32 can improve to (~0.6562, ~0.0055) # noqa: E501 - (torch.float32, 4e-4, 2e-9), + (torch.bfloat16, 1e-2), + (torch.float32, 1e-9), ], ) -def test_apply_correctness(dtype, exp_max, exp_mse): +def test_apply_correctness(dtype, exp_mse): model = AutoModelForCausalLM.from_pretrained( "meta-llama/Meta-Llama-3-8B-Instruct", device_map="cuda", torch_dtype=dtype ) + state = State(model=model) + modifier = QuIPModifier(transform_type="random-hadamard") input = {k: v.to("cuda") for k, v in model.dummy_inputs.items()} with torch.no_grad(): true_output = model(**input) - apply_transform_config(model, QUIP) + modifier.on_initialize(state) + modifier.on_start(state, None) + with torch.no_grad(): output = model(**input) - assert torch.max(true_output.logits - output.logits) <= exp_max assert torch.nn.MSELoss()(output.logits, true_output.logits) <= exp_mse From 9e975d3ef0ddda89beb1809843549d53d017ae1b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 5 Aug 2025 11:59:15 -0400 Subject: [PATCH 43/54] WIP: janice network issues Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/transform/quip/base.py | 1 + .../modifiers/transform/test_correctness.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/modifiers/transform/quip/base.py b/src/llmcompressor/modifiers/transform/quip/base.py index d34245055..9b9deb33e 100644 --- a/src/llmcompressor/modifiers/transform/quip/base.py +++ b/src/llmcompressor/modifiers/transform/quip/base.py @@ -45,6 +45,7 @@ class QuIPModifier(Modifier): ) randomize: bool = Field(default=False, exclude=True) learnable: bool = Field(default=False, exclude=True) + precision: ignore: Union[str, List[str]] = Field(default="lm_head", exclude=True) # optional override for 
more fine-grained control diff --git a/tests/llmcompressor/modifiers/transform/test_correctness.py b/tests/llmcompressor/modifiers/transform/test_correctness.py index 6d91f630c..32475d4ee 100644 --- a/tests/llmcompressor/modifiers/transform/test_correctness.py +++ b/tests/llmcompressor/modifiers/transform/test_correctness.py @@ -1,3 +1,4 @@ +import os import pytest import torch from transformers import AutoModelForCausalLM @@ -8,16 +9,20 @@ @requires_gpu +# @pytest.mark.skipif( +# (not os.getenv("HF_TOKEN")), +# reason="Skipping tracing tests requiring gated model access", +# ) @pytest.mark.parametrize( "dtype,exp_mse", [ - (torch.bfloat16, 1e-2), - (torch.float32, 1e-9), + (torch.bfloat16, 5e-3), + (torch.float32, 5e-11), ], ) def test_apply_correctness(dtype, exp_mse): model = AutoModelForCausalLM.from_pretrained( - "meta-llama/Meta-Llama-3-8B-Instruct", device_map="cuda", torch_dtype=dtype + "meta-llama/Llama-3.2-1B-Instruct", device_map="cuda", torch_dtype=dtype ) state = State(model=model) modifier = QuIPModifier(transform_type="random-hadamard") @@ -32,4 +37,5 @@ def test_apply_correctness(dtype, exp_mse): with torch.no_grad(): output = model(**input) + print(torch.nn.MSELoss()(output.logits, true_output.logits)) assert torch.nn.MSELoss()(output.logits, true_output.logits) <= exp_mse From 5392b2b7c3d8484900421462035f2a90bb86d43a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 5 Aug 2025 16:43:26 +0000 Subject: [PATCH 44/54] cleanup Signed-off-by: Kyle Sayers --- src/llmcompressor/modeling/fuse.py | 63 +++++++++---------- .../modifiers/transform/quip/base.py | 18 +++--- tests/llmcompressor/modeling/test_fuse.py | 6 +- .../transform/{ => quip}/test_correctness.py | 23 ++++--- .../transform/quip/test_serialization.py | 7 +++ 5 files changed, 64 insertions(+), 53 deletions(-) rename tests/llmcompressor/modifiers/transform/{ => quip}/test_correctness.py (59%) create mode 100644 tests/llmcompressor/modifiers/transform/quip/test_serialization.py diff --git a/src/llmcompressor/modeling/fuse.py b/src/llmcompressor/modeling/fuse.py index e59be596c..a8168a31f 100644 --- a/src/llmcompressor/modeling/fuse.py +++ b/src/llmcompressor/modeling/fuse.py @@ -6,58 +6,55 @@ get_execution_device, update_offload_parameter, ) -from transformers.models.llama.modeling_llama import LlamaRMSNorm -__all__ = ["normalize_embedding", "fuse_norm_linears"] +__all__ = ["center_embeddings", "fuse_norm_linears"] PRECISION = torch.float64 -def normalize_embedding(embedding: torch.nn.Module): +def center_embeddings(embedding: torch.nn.Module): """ - Normalize each embedding to have a mean of zero + Shift each embedding to have a mean of zero :param embedding: embedding module containing embeddings to center """ - if isinstance(embedding, (torch.nn.Embedding)): - with align_module_device(embedding): - weight_dtype = embedding.weight.dtype - weight = embedding.weight.to(PRECISION) - new_weight = weight - weight.mean(dim=-1, keepdim=True) - new_weight = new_weight.to(weight_dtype) + if not hasattr(embedding, "weight"): + raise ValueError(f"Cannot fuse norm of type {type(embedding)}") - update_offload_parameter(embedding, "weight", new_weight) + with align_module_device(embedding): + weight_dtype = embedding.weight.dtype + weight = embedding.weight.to(PRECISION) + new_weight = weight - weight.mean(dim=-1, keepdim=True) + new_weight = new_weight.to(weight_dtype) - else: - raise ValueError(f"Cannot normalize embedding of type {type(embedding)}") + update_offload_parameter(embedding, "weight", new_weight) def 
fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]): """ - Fuse a norm layer into subsequent linear layers. This useful for ensuring transform - invariance between norm and linear layers. + Fuse the scaling operation of the norm layer into subsequent linear layers. + This is useful for ensuring transform invariance between norm and linear layers. - Note that a model cannot be properly trained after its norms have been fused + Note that unitary transforms (rotations) commute with normalization, but not scaling :param norm: norm layer whose weight will be fused into subsequent linears :param linears: linear layers which directly follow the norm layer """ - if isinstance(norm, (torch.nn.RMSNorm, LlamaRMSNorm, torch.nn.LayerNorm)): - for linear in linears: - # NOTE: spinquant does this op in float64 - exec_device = get_execution_device(norm) - with align_module_device(norm, exec_device), align_module_device( - linear, exec_device - ): - weight_dtype = linear.weight.dtype - new_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION) - new_weight = new_weight.to(weight_dtype) - - update_offload_parameter(linear, "weight", new_weight) - - new_norm_weight = torch.ones_like(norm.weight, device="cpu") - update_offload_parameter(norm, "weight", new_norm_weight) - - else: + if not hasattr(norm, "weight"): raise ValueError(f"Cannot fuse norm of type {type(norm)}") + + for linear in linears: + # NOTE: spinquant does this op in float64 + exec_device = get_execution_device(norm) + with align_module_device(norm, exec_device), align_module_device( + linear, exec_device + ): + weight_dtype = linear.weight.dtype + new_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION) + new_weight = new_weight.to(weight_dtype) + + update_offload_parameter(linear, "weight", new_weight) + + new_norm_weight = torch.ones_like(norm.weight, device="cpu") + update_offload_parameter(norm, "weight", new_norm_weight) diff --git a/src/llmcompressor/modifiers/transform/quip/base.py b/src/llmcompressor/modifiers/transform/quip/base.py index 9b9deb33e..5d89dab95 100644 --- a/src/llmcompressor/modifiers/transform/quip/base.py +++ b/src/llmcompressor/modifiers/transform/quip/base.py @@ -1,11 +1,13 @@ from typing import List, Literal, Optional, Union +import torch from compressed_tensors.transform import ( TransformArgs, TransformConfig, TransformScheme, apply_transform_config, ) +from compressed_tensors.utils import TorchDtype from pydantic import Field, ValidationInfo, field_validator from llmcompressor.core import Event, EventType, State @@ -36,17 +38,19 @@ class QuIPModifier(Modifier): `"random-matrix"` has the greatest performance cost, but supports any size :param randomize: If true, create distinct transforms for each application :param learnable: If true, attach gradients to transform weights for training + :param precision: Precision at which all transforms should be applied.
This applies + to both weight fusing and online rotations :param ignore: Modules to ignore when attaching transforms :param transform_config: Optional transform config for overriding provided arguments """ transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field( - default="hadamard", exclude=True + default="random-hadamard" ) - randomize: bool = Field(default=False, exclude=True) - learnable: bool = Field(default=False, exclude=True) - precision: - ignore: Union[str, List[str]] = Field(default="lm_head", exclude=True) + randomize: bool = Field(default=False) + learnable: bool = Field(default=False) + precision: TorchDtype = Field(default=torch.float64) + ignore: Union[str, List[str]] = Field(default="lm_head") # optional override for more fine-grained control # also included in recipe serialization @@ -105,13 +109,13 @@ def _create_config(self) -> TransformConfig: TransformArgs( targets=["Linear"], location="weight_input", - # location="input", inverse=True, ignore=self.ignore, ), ], randomize=self.randomize, requires_grad=self.learnable, + precision=self.precision, ), "u": TransformScheme( type=self.transform_type, @@ -119,7 +123,6 @@ def _create_config(self) -> TransformConfig: TransformArgs( targets=["Linear"], location="weight_output", - # location="output", ignore=self.ignore, ), TransformArgs( @@ -131,6 +134,7 @@ def _create_config(self) -> TransformConfig: ], randomize=self.randomize, requires_grad=self.learnable, + precision=self.precision, ), } ) diff --git a/tests/llmcompressor/modeling/test_fuse.py b/tests/llmcompressor/modeling/test_fuse.py index 5798f692c..005d89f99 100644 --- a/tests/llmcompressor/modeling/test_fuse.py +++ b/tests/llmcompressor/modeling/test_fuse.py @@ -1,13 +1,13 @@ import pytest import torch -from llmcompressor.modeling.fuse import fuse_norm_linears, normalize_embedding +from llmcompressor.modeling.fuse import center_embeddings, fuse_norm_linears @pytest.mark.unit -def test_normalize_embedding(): +def test_center_embeddings(): embedding = torch.nn.Embedding(10, 10) - normalize_embedding(embedding) + center_embeddings(embedding) assert torch.allclose( embedding.weight.mean(dim=1), torch.zeros(embedding.num_embeddings), atol=1e-5 diff --git a/tests/llmcompressor/modifiers/transform/test_correctness.py b/tests/llmcompressor/modifiers/transform/quip/test_correctness.py similarity index 59% rename from tests/llmcompressor/modifiers/transform/test_correctness.py rename to tests/llmcompressor/modifiers/transform/quip/test_correctness.py index 32475d4ee..276060b6b 100644 --- a/tests/llmcompressor/modifiers/transform/test_correctness.py +++ b/tests/llmcompressor/modifiers/transform/quip/test_correctness.py @@ -1,4 +1,5 @@ import os + import pytest import torch from transformers import AutoModelForCausalLM @@ -9,23 +10,25 @@ @requires_gpu -# @pytest.mark.skipif( -# (not os.getenv("HF_TOKEN")), -# reason="Skipping tracing tests requiring gated model access", -# ) +@pytest.mark.skipif( + (not os.getenv("HF_TOKEN")), + reason="Skipping tracing tests requiring gated model access", +) @pytest.mark.parametrize( - "dtype,exp_mse", + "model_dtype,precision,exp_mse", [ - (torch.bfloat16, 5e-3), - (torch.float32, 5e-11), + (torch.bfloat16, torch.bfloat16, 5e-3), # 0.0019 + (torch.bfloat16, torch.float32, 5e-3), # 0.0022 + (torch.float32, torch.float32, 5e-10), # 1.0777e-10 + (torch.float32, torch.float64, 5e-11), # 2.6632e-11 ], ) -def test_apply_correctness(dtype, exp_mse): +def test_apply_correctness(model_dtype, precision, exp_mse): model = 
AutoModelForCausalLM.from_pretrained( - "meta-llama/Llama-3.2-1B-Instruct", device_map="cuda", torch_dtype=dtype + "meta-llama/Llama-3.2-1B-Instruct", device_map="cuda", torch_dtype=model_dtype ) state = State(model=model) - modifier = QuIPModifier(transform_type="random-hadamard") + modifier = QuIPModifier(transform_type="random-hadamard", precision=precision) input = {k: v.to("cuda") for k, v in model.dummy_inputs.items()} with torch.no_grad(): diff --git a/tests/llmcompressor/modifiers/transform/quip/test_serialization.py b/tests/llmcompressor/modifiers/transform/quip/test_serialization.py new file mode 100644 index 000000000..3dc682728 --- /dev/null +++ b/tests/llmcompressor/modifiers/transform/quip/test_serialization.py @@ -0,0 +1,7 @@ +from llmcompressor.modifiers.transform import QuIPModifier + + +def test_reload(): + modifier = QuIPModifier(transform_type="hadamard") + dump = modifier.model_dump() + assert QuIPModifier.model_validate(dump) == modifier From 5015d71014b6b619426821e92216524fd90d6f92 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 5 Aug 2025 16:50:12 +0000 Subject: [PATCH 45/54] add disclaimer Signed-off-by: Kyle Sayers --- examples/transform/quip_example.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/transform/quip_example.py b/examples/transform/quip_example.py index 26f76f4ec..01fd39784 100644 --- a/examples/transform/quip_example.py +++ b/examples/transform/quip_example.py @@ -1,3 +1,8 @@ +""" +NOTE: models produced using this example will not be capable of running in vLLM. +See https://github.com/vllm-project/vllm/pull/22219 for progress updates +""" + from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer From 6311eef6c291dd8f3a9c730a91d174878ef5c589 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 5 Aug 2025 16:55:13 +0000 Subject: [PATCH 46/54] more disclaimer Signed-off-by: Kyle Sayers --- examples/transform/quip_example.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/transform/quip_example.py b/examples/transform/quip_example.py index 01fd39784..72a6f3830 100644 --- a/examples/transform/quip_example.py +++ b/examples/transform/quip_example.py @@ -1,5 +1,6 @@ """ NOTE: models produced using this example will not be capable of running in vLLM. +You will also need to install `transformers>=4.56` or install from source See https://github.com/vllm-project/vllm/pull/22219 for progress updates """ From 2042eb615ab8eb36c25479ae378cce8445440341 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 5 Aug 2025 17:11:23 +0000 Subject: [PATCH 47/54] update disclaimer Signed-off-by: Kyle Sayers --- examples/transform/quip_example.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/examples/transform/quip_example.py b/examples/transform/quip_example.py index 72a6f3830..9a0cacee7 100644 --- a/examples/transform/quip_example.py +++ b/examples/transform/quip_example.py @@ -1,17 +1,32 @@ """ -NOTE: models produced using this example will not be capable of running in vLLM. 
-You will also need to install `transformers>=4.56` or install from source -See https://github.com/vllm-project/vllm/pull/22219 for progress updates +WARNING: This example requires the following minimum versions: + * compressed-tensors>=0.10.3.dev + * transformers>=4.56.dev +Note that (you may need to install from source) + +Models produced by this example will not be runnable in vLLM without +the following changes: https://github.com/vllm-project/vllm/pull/22219 """ from datasets import load_dataset +from packaging import version from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.utils.import_utils import _is_package_available from llmcompressor import oneshot from llmcompressor.modifiers.quantization import QuantizationModifier from llmcompressor.modifiers.transform import QuIPModifier from llmcompressor.utils import dispatch_for_generation +# check correct versioning +_, ct_version = _is_package_available("compressed_tensors", return_version=True) +_, tfms_version = _is_package_available("transformers", return_version=True) +if version.parse(ct_version) < version.parse("0.10.3.dev"): + print(version.parse(ct_version)) + raise ValueError("Please install compressed-tensors>=0.10.3 or from source") +if version.parse(tfms_version) < version.parse("4.56.dev"): + raise ValueError("Please install transformers>=4.56 or from source") + # Select model and load it. MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct" From ec83dc4e231e9884ce52c84c6f9065f436e40983 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 6 Aug 2025 16:21:53 +0000 Subject: [PATCH 48/54] remove extra file Signed-off-by: Kyle Sayers --- compress_model.py | 60 ----------------------------------------------- 1 file changed, 60 deletions(-) delete mode 100644 compress_model.py diff --git a/compress_model.py b/compress_model.py deleted file mode 100644 index fa67bead0..000000000 --- a/compress_model.py +++ /dev/null @@ -1,60 +0,0 @@ -# python3 compress_model.py --model_id meta-llama/Llama-3.2-1B-Instruct --transform_type random-hadamard -import argparse -from transformers import AutoModelForCausalLM, AutoTokenizer - -from llmcompressor import oneshot -from llmcompressor.modifiers.quantization import QuantizationModifier -from llmcompressor.modifiers.transform import SpinQuantModifier -from llmcompressor.utils import dispatch_for_generation - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--model_id", type=str, help="Model stub to compress") - parser.add_argument("--transform_type", type=str, default=None, help="Type of transform used in SpinQuantModifier") - parser.add_argument("--scheme", type=str, default=None, help="Quantization scheme (e.g. W4A16)") - return parser.parse_args() - -if __name__ == "__main__": - args = parse_args() - - # Select model and load it. - MODEL_ID = args.model_id - model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") - tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) - - # Select number of samples. 512 samples is a good place to start. - # Increasing the number of samples can improve accuracy. - NUM_CALIBRATION_SAMPLES = 512 - MAX_SEQUENCE_LENGTH = 2048 - - # Configure the quantization algorithm to run. - recipe = [] - if args.transform_type: - recipe.append(SpinQuantModifier(rotations=["R1", "R2"], transform_type=args.transform_type)) - - if args.scheme: - recipe.append(QuantizationModifier(targets="Linear", scheme=args.scheme, ignore=["lm_head"])) - - # Apply algorithms. 
- oneshot( - model=model, - recipe=recipe, - dataset="ultrachat_200k", - splits={"calibration": f"train_sft[:{NUM_CALIBRATION_SAMPLES}]"}, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, - ) - - # Confirm generations of the quantized model look sane. - print("\n\n") - print("========== SAMPLE GENERATION ==============") - dispatch_for_generation(model) - input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") - output = model.generate(input_ids, max_new_tokens=100) - print(tokenizer.decode(output[0])) - print("==========================================\n\n") - - # Save to disk compressed. - SAVE_DIR = MODEL_ID.split("/")[1] + f"-{args.transform_type}-{args.scheme}" - model.save_pretrained(SAVE_DIR, save_compressed=True) - tokenizer.save_pretrained(SAVE_DIR) From 884db4bd720a7babd203dcac040e7d0cc8d2e492 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 7 Aug 2025 00:47:35 -0400 Subject: [PATCH 49/54] fix style Signed-off-by: Kyle Sayers --- examples/quantization_w8a8_fp8/fp8_block_example.py | 4 +++- examples/transform/quip_example.py | 1 - src/llmcompressor/modifiers/transform/quip/base.py | 6 +++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/quantization_w8a8_fp8/fp8_block_example.py b/examples/quantization_w8a8_fp8/fp8_block_example.py index e977110ad..03a4c0bd6 100644 --- a/examples/quantization_w8a8_fp8/fp8_block_example.py +++ b/examples/quantization_w8a8_fp8/fp8_block_example.py @@ -16,7 +16,9 @@ # * quantize the weights to fp8 with per channel via ptq # * quantize the activations to fp8 with dynamic per token recipe = QuantizationModifier( - targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head", "re:.*mlp.gate$"], + targets="Linear", + scheme="FP8_BLOCK", + ignore=["lm_head", "re:.*mlp.gate$"], ) # Apply quantization. diff --git a/examples/transform/quip_example.py b/examples/transform/quip_example.py index 9a0cacee7..773295722 100644 --- a/examples/transform/quip_example.py +++ b/examples/transform/quip_example.py @@ -22,7 +22,6 @@ _, ct_version = _is_package_available("compressed_tensors", return_version=True) _, tfms_version = _is_package_available("transformers", return_version=True) if version.parse(ct_version) < version.parse("0.10.3.dev"): - print(version.parse(ct_version)) raise ValueError("Please install compressed-tensors>=0.10.3 or from source") if version.parse(tfms_version) < version.parse("4.56.dev"): raise ValueError("Please install transformers>=4.56 or from source") diff --git a/src/llmcompressor/modifiers/transform/quip/base.py b/src/llmcompressor/modifiers/transform/quip/base.py index 5d89dab95..1d753bc15 100644 --- a/src/llmcompressor/modifiers/transform/quip/base.py +++ b/src/llmcompressor/modifiers/transform/quip/base.py @@ -19,8 +19,8 @@ class QuIPModifier(Modifier): """ Implements the transforms according to - [QuIP#: Even Better LLM Quantization with Hadamard Incoherence and Lattice Codebooks](https://arxiv.org/pdf/2402.04396) # noqa: E501 - [QuIP: 2-Bit Quantization of Large Language Models With Guarantees](https://arxiv.org/abs/2307.13304) # noqa: E501 + [QuIP#: Even Better LLM Quantization with Hadamard Incoherence and Lattice Codebooks](https://arxiv.org/pdf/2402.04396) + [QuIP: 2-Bit Quantization of Large Language Models With Guarantees](https://arxiv.org/abs/2307.13304) Transforms (rotations) are extra layers added to a model which reduce the accuracy loss induced by quantization. 
This is achieved through "rotating" weights and @@ -42,7 +42,7 @@ class QuIPModifier(Modifier): to both weight fusing and online rotations :param ignore: Modules to ignore when attaching transforms :param transform_config: Optional transform config for overriding provided arguments - """ + """ # noqa: E501 transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field( default="random-hadamard" From 1f5ce4c05e7e61a04edffb1ff630d8edfa07793c Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 7 Aug 2025 18:31:09 -0400 Subject: [PATCH 50/54] update example Signed-off-by: Kyle Sayers --- examples/transform/quip_example.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/examples/transform/quip_example.py b/examples/transform/quip_example.py index 773295722..bd497e124 100644 --- a/examples/transform/quip_example.py +++ b/examples/transform/quip_example.py @@ -1,31 +1,16 @@ """ -WARNING: This example requires the following minimum versions: - * compressed-tensors>=0.10.3.dev - * transformers>=4.56.dev -Note that (you may need to install from source) - -Models produced by this example will not be runnable in vLLM without +NOTE: Models produced by this example will not be runnable in vLLM without the following changes: https://github.com/vllm-project/vllm/pull/22219 """ from datasets import load_dataset -from packaging import version from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers.utils.import_utils import _is_package_available from llmcompressor import oneshot from llmcompressor.modifiers.quantization import QuantizationModifier from llmcompressor.modifiers.transform import QuIPModifier from llmcompressor.utils import dispatch_for_generation -# check correct versioning -_, ct_version = _is_package_available("compressed_tensors", return_version=True) -_, tfms_version = _is_package_available("transformers", return_version=True) -if version.parse(ct_version) < version.parse("0.10.3.dev"): - raise ValueError("Please install compressed-tensors>=0.10.3 or from source") -if version.parse(tfms_version) < version.parse("4.56.dev"): - raise ValueError("Please install transformers>=4.56 or from source") - # Select model and load it. 
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct" From 7324f4bec1ad8bf13402d16201b530279a98854c Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 8 Aug 2025 00:36:48 -0400 Subject: [PATCH 51/54] move tests files Signed-off-by: Kyle Sayers --- .../transform/quip/test_correctness.py | 44 ---------------- .../transform/quip/test_serialization.py | 7 --- .../modifiers/transform/test_correctness.py | 52 +++++++++++++++++++ .../modifiers/transform/test_serialization.py | 10 ++++ 4 files changed, 62 insertions(+), 51 deletions(-) delete mode 100644 tests/llmcompressor/modifiers/transform/quip/test_correctness.py delete mode 100644 tests/llmcompressor/modifiers/transform/quip/test_serialization.py create mode 100644 tests/llmcompressor/modifiers/transform/test_correctness.py create mode 100644 tests/llmcompressor/modifiers/transform/test_serialization.py diff --git a/tests/llmcompressor/modifiers/transform/quip/test_correctness.py b/tests/llmcompressor/modifiers/transform/quip/test_correctness.py deleted file mode 100644 index 276060b6b..000000000 --- a/tests/llmcompressor/modifiers/transform/quip/test_correctness.py +++ /dev/null @@ -1,44 +0,0 @@ -import os - -import pytest -import torch -from transformers import AutoModelForCausalLM - -from llmcompressor.core import State -from llmcompressor.modifiers.transform import QuIPModifier -from tests.testing_utils import requires_gpu - - -@requires_gpu -@pytest.mark.skipif( - (not os.getenv("HF_TOKEN")), - reason="Skipping tracing tests requiring gated model access", -) -@pytest.mark.parametrize( - "model_dtype,precision,exp_mse", - [ - (torch.bfloat16, torch.bfloat16, 5e-3), # 0.0019 - (torch.bfloat16, torch.float32, 5e-3), # 0.0022 - (torch.float32, torch.float32, 5e-10), # 1.0777e-10 - (torch.float32, torch.float64, 5e-11), # 2.6632e-11 - ], -) -def test_apply_correctness(model_dtype, precision, exp_mse): - model = AutoModelForCausalLM.from_pretrained( - "meta-llama/Llama-3.2-1B-Instruct", device_map="cuda", torch_dtype=model_dtype - ) - state = State(model=model) - modifier = QuIPModifier(transform_type="random-hadamard", precision=precision) - - input = {k: v.to("cuda") for k, v in model.dummy_inputs.items()} - with torch.no_grad(): - true_output = model(**input) - - modifier.on_initialize(state) - modifier.on_start(state, None) - - with torch.no_grad(): - output = model(**input) - - print(torch.nn.MSELoss()(output.logits, true_output.logits)) - assert torch.nn.MSELoss()(output.logits, true_output.logits) <= exp_mse diff --git a/tests/llmcompressor/modifiers/transform/quip/test_serialization.py b/tests/llmcompressor/modifiers/transform/quip/test_serialization.py deleted file mode 100644 index 3dc682728..000000000 --- a/tests/llmcompressor/modifiers/transform/quip/test_serialization.py +++ /dev/null @@ -1,7 +0,0 @@ -from llmcompressor.modifiers.transform import QuIPModifier - - -def test_reload(): - modifier = QuIPModifier(transform_type="hadamard") - dump = modifier.model_dump() - assert QuIPModifier.model_validate(dump) == modifier diff --git a/tests/llmcompressor/modifiers/transform/test_correctness.py b/tests/llmcompressor/modifiers/transform/test_correctness.py new file mode 100644 index 000000000..7e2223d88 --- /dev/null +++ b/tests/llmcompressor/modifiers/transform/test_correctness.py @@ -0,0 +1,52 @@ +import os + +import pytest +import torch +from transformers import AutoModelForCausalLM + +from llmcompressor.core import State +from llmcompressor.modifiers.transform import QuIPModifier +from 
llmcompressor.transformers.sparsification.compressed_tensors_utils import ( + untie_word_embeddings, +) +from tests.testing_utils import requires_gpu + + +@requires_gpu +@pytest.mark.skipif( + (not os.getenv("HF_TOKEN")), + reason="Skipping tracing tests requiring gated model access", +) +@pytest.mark.parametrize( + "modifier,model_dtype,precision,exp_mse", + [ + (QuIPModifier, torch.bfloat16, torch.bfloat16, 5e-3), # 0.0019 + (QuIPModifier, torch.bfloat16, torch.float32, 5e-3), # 0.0022 + (QuIPModifier, torch.float32, torch.float32, 5e-10), # 1.0e-10 + (QuIPModifier, torch.float32, torch.float64, 5e-11), # 2.7e-11 + # (SpinQuantModifier, torch.bfloat16, torch.bfloat16, 5e-3), # 0.0043 + # (SpinQuantModifier, torch.bfloat16, torch.float32, 5e-3), # 0.0033 + # (SpinQuantModifier, torch.float32, torch.float32, 5e-4), # 4e-4 + # (SpinQuantModifier, torch.float32, torch.float64, 5e-4), # 4e-4 + ], +) +def test_apply_correctness(modifier, model_dtype, precision, exp_mse): + model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-3.2-1B-Instruct", device_map="cuda", torch_dtype=model_dtype + ) + untie_word_embeddings(model) + + state = State(model=model) + modifier = modifier(transform_type="random-hadamard", precision=precision) + + input = {k: v.to("cuda") for k, v in model.dummy_inputs.items()} + with torch.no_grad(): + true_output = model(**input) + + modifier.on_initialize(state) + modifier.on_start(state, None) + + with torch.no_grad(): + output = model(**input) + + assert torch.nn.MSELoss()(output.logits, true_output.logits) <= exp_mse diff --git a/tests/llmcompressor/modifiers/transform/test_serialization.py b/tests/llmcompressor/modifiers/transform/test_serialization.py new file mode 100644 index 000000000..eba730dba --- /dev/null +++ b/tests/llmcompressor/modifiers/transform/test_serialization.py @@ -0,0 +1,10 @@ +import pytest + +from llmcompressor.modifiers.transform import QuIPModifier + + +@pytest.mark.parametrize("modifier", [QuIPModifier]) +def test_reload(modifier): + instance = modifier(transform_type="hadamard") + dump = instance.model_dump() + assert modifier.model_validate(dump) == instance From 7d34cca59c234cf9a5951bac24bb84ddefdae614 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 8 Aug 2025 00:53:47 -0400 Subject: [PATCH 52/54] remove calib dataset, add note Signed-off-by: Kyle Sayers --- examples/transform/quip_example.py | 59 +++--------------------------- 1 file changed, 5 insertions(+), 54 deletions(-) diff --git a/examples/transform/quip_example.py b/examples/transform/quip_example.py index bd497e124..61a46866b 100644 --- a/examples/transform/quip_example.py +++ b/examples/transform/quip_example.py @@ -1,9 +1,8 @@ """ NOTE: Models produced by this example will not be runnable in vLLM without -the following changes: https://github.com/vllm-project/vllm/pull/22219 +the following changes: https://github.com/vllm-project/vllm/pull/22486 """ -from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer from llmcompressor import oneshot @@ -12,53 +11,12 @@ from llmcompressor.utils import dispatch_for_generation # Select model and load it. 
+# NOTE: because the datafree pipeline is being used in this +# example, you can use additional GPUs to support larger models MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct" - -model = AutoModelForCausalLM.from_pretrained( - MODEL_ID, - torch_dtype="auto", -) +model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) -# Select calibration dataset. -DATASET_ID = "HuggingFaceH4/ultrachat_200k" -DATASET_SPLIT = "train_sft" - -# Select number of samples. 512 samples is a good place to start. -# Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 512 -MAX_SEQUENCE_LENGTH = 2048 - -# Load dataset and preprocess. -ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") -ds = ds.shuffle(seed=42) - - -def preprocess(example): - return { - "text": tokenizer.apply_chat_template( - example["messages"], - tokenize=False, - ) - } - - -ds = ds.map(preprocess) - - -# Tokenize inputs. -def tokenize(sample): - return tokenizer( - sample["text"], - padding=False, - max_length=MAX_SEQUENCE_LENGTH, - truncation=True, - add_special_tokens=False, - ) - - -ds = ds.map(tokenize, remove_columns=ds.column_names) - # Configure the quantization algorithm to run. # * apply spinquant transforms to model in order to make quantization easier # * quantize the weights to 4 bit with GPTQ with a group size 128 recipe = [ QuIPModifier(transform_type="random-hadamard"), QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), ] # Apply algorithms. -oneshot( - model=model, - recipe=recipe, - dataset=ds, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, - pipeline="datafree", -) +oneshot(model=model, recipe=recipe, pipeline="datafree") # Confirm generations of the quantized model look sane. print("\n\n") From 972f59f56cf297cf29cee0be1c5c579fad90f360 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 11 Aug 2025 17:21:28 -0400 Subject: [PATCH 53/54] add targets field Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/transform/quip/base.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/modifiers/transform/quip/base.py b/src/llmcompressor/modifiers/transform/quip/base.py index 1d753bc15..320ab6df0 100644 --- a/src/llmcompressor/modifiers/transform/quip/base.py +++ b/src/llmcompressor/modifiers/transform/quip/base.py @@ -47,6 +47,7 @@ class QuIPModifier(Modifier): transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field( default="random-hadamard" ) + targets: Union[List[str], str] = Field(default="Linear") randomize: bool = Field(default=False) learnable: bool = Field(default=False) precision: TorchDtype = Field(default=torch.float64) @@ -102,12 +103,12 @@ def _create_config(self) -> TransformConfig: type=self.transform_type, apply=[ TransformArgs( - targets=["Linear"], + targets=self.targets, location="input", # non-mergable ignore=self.ignore, ), TransformArgs( - targets=["Linear"], + targets=self.targets, location="weight_input", inverse=True, ignore=self.ignore, @@ -121,12 +122,12 @@ def _create_config(self) -> TransformConfig: type=self.transform_type, apply=[ TransformArgs( - targets=["Linear"], + targets=self.targets, location="weight_output", ignore=self.ignore, ), TransformArgs( - targets=["Linear"], + targets=self.targets, location="output", # non-mergable inverse=True, ignore=self.ignore, From 1eb6d0919dd64769d7997d4a9357b98cec7b712e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 14 Aug 2025 15:34:43 -0400 Subject: [PATCH 54/54] update docstrings
Signed-off-by: Kyle Sayers --- examples/transform/quip_example.py | 2 +- src/llmcompressor/modifiers/transform/quip/base.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/examples/transform/quip_example.py b/examples/transform/quip_example.py index 61a46866b..2c989b2d7 100644 --- a/examples/transform/quip_example.py +++ b/examples/transform/quip_example.py @@ -19,7 +19,7 @@ # Configure the quantization algorithm to run. # * apply spinquant transforms to model in order to make quantization easier -# * quantize the weights to 4 bit with GPTQ with a group size 128 +# * quantize the weights to 4 bit with a group size 128 recipe = [ QuIPModifier(transform_type="random-hadamard"), QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), ] diff --git a/src/llmcompressor/modifiers/transform/quip/base.py b/src/llmcompressor/modifiers/transform/quip/base.py index 320ab6df0..42abbf634 100644 --- a/src/llmcompressor/modifiers/transform/quip/base.py +++ b/src/llmcompressor/modifiers/transform/quip/base.py @@ -30,6 +30,17 @@ class QuIPModifier(Modifier): QuIP and QuIP# apply transforms to every linear layer, two of which are fused into the model weights and two of which remain as online rotations computed at runtime. + Lifecycle: + - on_initialize + - as needed, create transform schemes for the V and U rotations + - on_start + - apply TransformConfig + - fuse transforms into weights for mergeable transforms + - add hooks for online transforms + - on sequential epoch end + - on_end + - on_finalize + :param transform_type: The type of transform to apply to the model. `"hadamard"` has the least performance cost but only supports sizes which are powers of two.
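
A minimal end-to-end sketch of how the QuIPModifier introduced by this series is used, mirroring examples/transform/quip_example.py from the final patches; the model ID, W4A16 scheme, and save directory below are illustrative choices rather than requirements:

from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import QuIPModifier

# illustrative model choice; QuIP transforms are data-free, so no calibration dataset is needed
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# apply QuIP rotations first, then quantize the rotated weights to 4 bits
recipe = [
    QuIPModifier(transform_type="random-hadamard"),
    QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]
oneshot(model=model, recipe=recipe, pipeline="datafree")

# save in compressed-tensors format (illustrative output directory name)
SAVE_DIR = MODEL_ID.split("/")[1] + "-quip-w4a16"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

As the example's disclaimer notes, models saved this way are not runnable in vLLM until the linked vLLM change lands.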