Skip to content

[Transform] QuIP Modifier #1648

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 64 commits into from
Aug 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
ba617db
wip
kylesayrs Jun 6, 2025
2f5b1c8
use random-hadamard, add correctness tests
kylesayrs Jun 12, 2025
3aa35e7
add correctness test, note that precision makes a large difference
kylesayrs Jun 12, 2025
b6c088e
add on lifecycle methods
brian-dellabetta Jun 23, 2025
d1eb2a1
Merge branch 'main' into kylesayrs/transform-modifier
brian-dellabetta Jul 1, 2025
3207124
TransformModifier with SpinQuant R1&R2
brian-dellabetta Jul 2, 2025
a88ca3c
spinquant and quip_online, running but outputting gibberish
brian-dellabetta Jul 2, 2025
5bd51df
updated example
brian-dellabetta Jul 2, 2025
3c216dd
DummyModel script
brian-dellabetta Jul 8, 2025
bbcdc8c
implement fuse_norm_linears
kylesayrs Jul 10, 2025
bd7f4d5
Merge branch 'kylesayrs/fuse-helpers' into bdellabe/transform-modifier
kylesayrs Jul 10, 2025
f5c2150
R1 working
kylesayrs Jul 11, 2025
dc5c30c
add r2, increase precision
kylesayrs Jul 11, 2025
7172c26
spinquant modifier
kylesayrs Jul 11, 2025
9298e82
remove space
kylesayrs Jul 11, 2025
f77226d
use iterable
kylesayrs Jul 11, 2025
fdb64b5
add rotation validation
kylesayrs Jul 11, 2025
5daa2d5
embedding fusion
kylesayrs Jul 11, 2025
0e9af7b
add missing norm fusion
kylesayrs Jul 12, 2025
fce83be
use norm mappings
kylesayrs Jul 12, 2025
a979f8a
break into separate files
kylesayrs Jul 12, 2025
4cab29e
small cleanup
kylesayrs Jul 12, 2025
f1cc987
cleanup
kylesayrs Jul 14, 2025
a7bb2e2
more cleanup
kylesayrs Jul 14, 2025
0cf0188
make new weight on cpu
kylesayrs Jul 14, 2025
53ea307
standardize, make modifier serializable
kylesayrs Jul 14, 2025
4b4257f
add compress model script
kylesayrs Jul 14, 2025
dc7ac1a
use untie_word_embeddings
kylesayrs Jul 15, 2025
8542f8d
style
kylesayrs Jul 15, 2025
b1e637e
better registery logic
kylesayrs Jul 15, 2025
b44ac81
remove dummy model test (add later)
kylesayrs Jul 15, 2025
7a52b71
docstring
kylesayrs Jul 15, 2025
f4d7ec6
update docstring
kylesayrs Jul 15, 2025
f18d0e8
rename example file
kylesayrs Jul 15, 2025
cec2914
use match_modules_set
kylesayrs Jul 16, 2025
f6c797e
Merge branch 'main' into bdellabe/transform-modifier
brian-dellabetta Jul 16, 2025
0c5c514
unit test fixes
brian-dellabetta Jul 17, 2025
f2ef7cf
style fixes
brian-dellabetta Jul 17, 2025
d0e5bc5
remove hardcoded pipeline logic
brian-dellabetta Jul 24, 2025
31ac8e9
docstrings
brian-dellabetta Jul 24, 2025
a4abb3d
stylefixes
brian-dellabetta Jul 24, 2025
490b987
implement quip
kylesayrs Jul 15, 2025
ac7dbcd
add example, cleanup
kylesayrs Jul 15, 2025
a5d3ddc
update quip example
kylesayrs Aug 1, 2025
a21648d
prepare for merge without spinquant
kylesayrs Aug 4, 2025
9e975d3
WIP: janice network issues
kylesayrs Aug 5, 2025
5392b2b
cleanup
kylesayrs Aug 5, 2025
ddab3d2
Merge remote-tracking branch 'origin' into kylesayrs/transform-quip-m…
kylesayrs Aug 5, 2025
5015d71
add disclaimer
kylesayrs Aug 5, 2025
6311eef
more disclaimer
kylesayrs Aug 5, 2025
2042eb6
update disclaimer
kylesayrs Aug 5, 2025
ec83dc4
remove extra file
kylesayrs Aug 6, 2025
b047914
Merge remote-tracking branch 'origin' into kylesayrs/transform-quip-m…
kylesayrs Aug 7, 2025
884db4b
fix style
kylesayrs Aug 7, 2025
1f5ce4c
update example
kylesayrs Aug 7, 2025
7324f4b
move tests files
kylesayrs Aug 8, 2025
7d34cca
remove calib dataset, add note
kylesayrs Aug 8, 2025
972f59f
add targets field
kylesayrs Aug 11, 2025
62958bb
Merge branch 'main' into kylesayrs/transform-quip-modifier
kylesayrs Aug 12, 2025
12e5ca8
Merge remote-tracking branch 'origin' into kylesayrs/transform-quip-m…
kylesayrs Aug 13, 2025
f86e3ac
Merge branch 'main' into kylesayrs/transform-quip-modifier
dsikka Aug 13, 2025
06d5967
Merge branch 'main' into kylesayrs/transform-quip-modifier
dsikka Aug 14, 2025
1eb6d09
update docstrings
kylesayrs Aug 14, 2025
d6af93e
Merge branch 'main' into kylesayrs/transform-quip-modifier
dsikka Aug 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions examples/transform/quip_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
NOTE: Models produced by this example will not be runnable in vLLM without
the following changes: https://github.com/vllm-project/vllm/pull/22486
"""

from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import QuIPModifier
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
# NOTE: because the datafree pipeline is being used in this
# example, you can use additional GPUs to support larger models
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Configure the quantization algorithm to run.
# * apply spinquant transforms to model in order to make quantization easier
# * quantize the weights to 4 bit with a group size 128
recipe = [
QuIPModifier(transform_type="random-hadamard"),
QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]

# Apply algorithms.
oneshot(model=model, recipe=recipe, pipeline="datafree")

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + "-quip-w4a16"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
1 change: 1 addition & 0 deletions src/llmcompressor/modifiers/transform/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# flake8: noqa

from .quip import QuIPModifier
from .spinquant import SpinQuantModifier
3 changes: 3 additions & 0 deletions src/llmcompressor/modifiers/transform/quip/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# flake8: noqa

from .base import *
155 changes: 155 additions & 0 deletions src/llmcompressor/modifiers/transform/quip/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from typing import List, Literal, Optional, Union

import torch
from compressed_tensors.transform import (
TransformArgs,
TransformConfig,
TransformScheme,
apply_transform_config,
)
from compressed_tensors.utils import TorchDtype
from pydantic import Field, ValidationInfo, field_validator

from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers import Modifier

__all__ = ["QuIPModifier"]


class QuIPModifier(Modifier):
"""
Implements the transforms according to
[QuIP#: Even Better LLM Quantization with Hadamard Incoherence and Lattice Codebooks](https://arxiv.org/pdf/2402.04396)
[QuIP: 2-Bit Quantization of Large Language Models With Guarantees](https://arxiv.org/abs/2307.13304)

Transforms (rotations) are extra layers added to a model which reduce the accuracy
loss induced by quantization. This is achieved through "rotating" weights and
activations into a space with a smaller dynamic range of values, thus decreasing
the range of scales required for quantization.

QuIP and QuIP# apply transforms to every linear layer, two of which are fused into
the model weights and two of which remain as online rotations computed at runtime.

Lifecycle:
- on_initialize
- infer SpinQuantMappings & NormMappings
- as needed, create transform schemes for R1, R2, R3, & R4
- on_start
- normalize embeddings
- fuse norm layers into subsequent Linear layers
- apply TransformConfig
- fuse transforms into weights for mergeable transforms
- add hooks for online transforms
- on sequential epoch end
- on_end
- on_finalize

:param transform_type: The type of transform to apply to the model.
`"hadamard"` has the least performance cost but only supports sizes which are
powers of power of two.
`"random-hadamard"` has more performance cost, but supports a much larger set of
sizes.
`"random-matrix"` has the greatest performance cost, but supports any size
:param randomize: If true, create distinct transforms for each application
:param learnable: If true, attach gradients to transform weights for training
:param precision: Precision at which all transforms should be applied. This applies
to both weight fusing and online rotations
:param ignore: Modules to ignore when attaching transforms
:param transform_config: Optional transform config for overriding provided arguments
""" # noqa: E501

transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(
default="random-hadamard"
)
targets: Union[List[str], str] = Field(default="str")
randomize: bool = Field(default=False)
learnable: bool = Field(default=False)
precision: TorchDtype = Field(default=torch.float64)
ignore: Union[str, List[str]] = Field(default="lm_head")

# optional override for more fine-grained control
# also included in recipe serialization
transform_config: Optional[TransformConfig] = Field(default=None, repr=False)

@field_validator("randomize", "learnable", mode="before")
def validate_not_implemented(cls, value, info: ValidationInfo):
if value:
raise NotImplementedError(f"{info.field_name} is not supported right now")
return value

def on_initialize(self, state: State, **kwargs) -> bool:
if self.transform_config is not None:
return True

self.transform_config = self._create_config()
return True

def on_start(self, state: State, event: Event, **kwargs):
self.started_ = True

apply_transform_config(state.model, self.transform_config)

def on_event(self, state: State, event: Event, **kwargs):
if event.type_ == EventType.CALIBRATION_EPOCH_START:
if not self.started_:
self.on_start(state, None)

elif event.type_ == EventType.SEQUENTIAL_EPOCH_END:
pass

elif event.type_ == EventType.CALIBRATION_EPOCH_END:
if not self.ended_:
self.on_end(state, None)

def on_end(self, state: State, event: Event, **kwargs):
self.ended_ = True

def on_finalize(self, state: State, **kwargs) -> bool:
if not self.ended_:
self.on_end(state, None)

return True

def _create_config(self) -> TransformConfig:
return TransformConfig(
config_groups={
"v": TransformScheme(
type=self.transform_type,
apply=[
TransformArgs(
targets=self.targets,
location="input", # non-mergable
ignore=self.ignore,
),
TransformArgs(
targets=self.targets,
location="weight_input",
inverse=True,
ignore=self.ignore,
),
],
randomize=self.randomize,
requires_grad=self.learnable,
precision=self.precision,
),
"u": TransformScheme(
type=self.transform_type,
apply=[
TransformArgs(
targets=self.targets,
location="weight_output",
ignore=self.ignore,
),
TransformArgs(
targets=self.targets,
location="output", # non-mergable
inverse=True,
ignore=self.ignore,
),
],
randomize=self.randomize,
requires_grad=self.learnable,
precision=self.precision,
),
}
)
10 changes: 5 additions & 5 deletions tests/llmcompressor/modifiers/transform/test_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from transformers import AutoModelForCausalLM

from llmcompressor.core import State
from llmcompressor.modifiers.transform import SpinQuantModifier
from llmcompressor.modifiers.transform import QuIPModifier, SpinQuantModifier
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
untie_word_embeddings,
)
Expand All @@ -20,10 +20,10 @@
@pytest.mark.parametrize(
"modifier,model_dtype,precision,exp_mse",
[
# (QuIPModifier, torch.bfloat16, torch.bfloat16, 5e-3), # 0.0019
# (QuIPModifier, torch.bfloat16, torch.float32, 5e-3), # 0.0022
# (QuIPModifier, torch.float32, torch.float32, 5e-10), # 1.0e-10
# (QuIPModifier, torch.float32, torch.float64, 5e-11), # 2.7e-11
(QuIPModifier, torch.bfloat16, torch.bfloat16, 5e-3), # 0.0019
(QuIPModifier, torch.bfloat16, torch.float32, 5e-3), # 0.0022
(QuIPModifier, torch.float32, torch.float32, 5e-10), # 1.0e-10
(QuIPModifier, torch.float32, torch.float64, 5e-11), # 2.7e-11
(SpinQuantModifier, torch.bfloat16, torch.bfloat16, 5e-3), # 0.0030
(SpinQuantModifier, torch.bfloat16, torch.float32, 5e-3), # 0.0029
(SpinQuantModifier, torch.float32, torch.float32, 5e-4), # 4e-4
Expand Down
4 changes: 2 additions & 2 deletions tests/llmcompressor/modifiers/transform/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import pytest

from llmcompressor.modifiers.transform import SpinQuantModifier
from llmcompressor.modifiers.transform import QuIPModifier, SpinQuantModifier


@pytest.mark.parametrize("modifier", [SpinQuantModifier])
@pytest.mark.parametrize("modifier", [SpinQuantModifier, QuIPModifier])
def test_reload(modifier):
instance = modifier(transform_type="hadamard")
dump = instance.model_dump()
Expand Down
Loading