Merged
Commits
67 commits
80c92da
add auto-round
yiliu30 Nov 3, 2025
75f7efd
Merge branch 'main' into up-ar
yiliu30 Nov 3, 2025
3266b79
add auto-round modifier
yiliu30 Nov 3, 2025
9c537cc
refine code
yiliu30 Nov 3, 2025
bebe0fa
disable qac for auto-round
yiliu30 Nov 3, 2025
dfb0ff8
clean code
yiliu30 Nov 3, 2025
513972c
add compile after disable qac
yiliu30 Nov 3, 2025
2291cc4
add iters and clean code
yiliu30 Nov 3, 2025
4028853
clean code
yiliu30 Nov 3, 2025
97ff9e0
add example
yiliu30 Nov 3, 2025
cb7a5b4
refine docs
yiliu30 Nov 3, 2025
5a7500e
refine example
yiliu30 Nov 3, 2025
d02a355
add init
yiliu30 Nov 3, 2025
cea9d2f
clean code
yiliu30 Nov 3, 2025
22be9b7
format
yiliu30 Nov 3, 2025
6cdb402
refactor
yiliu30 Nov 3, 2025
e2814eb
add ut
yiliu30 Nov 3, 2025
3e4a9fc
test llama 3
yiliu30 Nov 3, 2025
aa34b65
clean code
yiliu30 Nov 4, 2025
afe2ff7
parse layer-wise config
yiliu30 Nov 4, 2025
8e9eccc
format
yiliu30 Nov 4, 2025
81f76af
add docstring
yiliu30 Nov 4, 2025
afa6150
add ar
yiliu30 Nov 4, 2025
97217e7
update example
yiliu30 Nov 4, 2025
3dcb434
align api
yiliu30 Nov 5, 2025
aef7707
format
yiliu30 Nov 5, 2025
97e1ca2
clean code
yiliu30 Nov 5, 2025
c75c272
fix typo
yiliu30 Nov 5, 2025
3d8a0c8
small iters for ut
yiliu30 Nov 5, 2025
6729a75
format
yiliu30 Nov 5, 2025
bb4dbe8
refine comment
yiliu30 Nov 5, 2025
2adf0e7
replace paper link
yiliu30 Nov 5, 2025
dd9bde9
correct comments
yiliu30 Nov 5, 2025
4980229
Merge branch 'main' into autoround-support
yiliu30 Nov 5, 2025
7d97255
update comments
yiliu30 Nov 5, 2025
f298e82
refine code
yiliu30 Nov 5, 2025
73c3571
add more checks
yiliu30 Nov 5, 2025
eb16397
update example
yiliu30 Nov 6, 2025
9cb1f06
move auto-round to modifier
yiliu30 Nov 6, 2025
76e0d21
apply untie
yiliu30 Nov 6, 2025
1cbe919
correct docstring
yiliu30 Nov 6, 2025
9fa5efb
enable ci
yiliu30 Nov 6, 2025
7937d80
revert import AutoRoundModifier into modifier directly
yiliu30 Nov 6, 2025
e58b2bd
update
yiliu30 Nov 6, 2025
bd70ea6
Merge branch 'main' into autoround-support
yiliu30 Nov 6, 2025
6b236f6
merge main
yiliu30 Nov 7, 2025
4c94187
clean
yiliu30 Nov 7, 2025
7ea8442
fix
yiliu30 Nov 7, 2025
f52c0c0
refactor
yiliu30 Nov 7, 2025
4a9c4aa
format
yiliu30 Nov 7, 2025
0567df6
Update src/llmcompressor/modifiers/autoround/base.py
yiliu30 Nov 7, 2025
650a19c
refine docs
yiliu30 Nov 7, 2025
58e09bf
Merge branch 'autoround-support' of https://github.com/yiliu30/llm-co…
yiliu30 Nov 7, 2025
5cd35a6
fix import
yiliu30 Nov 8, 2025
678b123
Update src/llmcompressor/modifiers/autoround/base.py
yiliu30 Nov 8, 2025
a8c63d3
add qinput
yiliu30 Nov 10, 2025
38634dc
Merge branch 'autoround-support' of https://github.com/yiliu30/llm-co…
yiliu30 Nov 10, 2025
fbc047a
clean cache
yiliu30 Nov 10, 2025
96b6490
align api
yiliu30 Nov 10, 2025
d00d41b
fix
yiliu30 Nov 10, 2025
d4a8fb0
fix
yiliu30 Nov 10, 2025
487fcd2
update
yiliu30 Nov 10, 2025
baeea3f
Merge branch 'main' into autoround-support
yiliu30 Nov 11, 2025
3adc879
add requires_gpu for ut
yiliu30 Nov 12, 2025
ac10f7b
Merge branch 'main' into autoround-support
yiliu30 Nov 12, 2025
decb14f
Merge branch 'autoround-support' of https://github.com/yiliu30/llm-co…
yiliu30 Nov 12, 2025
f9dabc4
Merge branch 'main' into autoround-support
yiliu30 Nov 12, 2025
56 changes: 56 additions & 0 deletions examples/quantization_w4a16/auto_round_llama3_example.py
@@ -0,0 +1,56 @@
from auto_round.calib_dataset import get_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import AutoRoundModifier
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Select calibration dataset.
NUM_CALIBRATION_SAMPLES = 128
MAX_SEQUENCE_LENGTH = 2048
# Get aligned calibration dataset.

ds = get_dataset(
    tokenizer=tokenizer,
    seqlen=MAX_SEQUENCE_LENGTH,
    nsamples=NUM_CALIBRATION_SAMPLES,
)


# Configure the quantization algorithm to run.
#   * quantize the weights to 4 bits with AutoRound, using a group size of 128
recipe = AutoRoundModifier(
targets="Linear", scheme="W4A16", ignore=["lm_head"], iters=200
)


# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    # disable shuffling to get a slightly better MMLU score
    shuffle_calibration_samples=False,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128-AutoRound"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
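Note for anyone trying the example end to end (not part of this diff): the compressed checkpoint written to SAVE_DIR can be loaded back for inference. A minimal sketch, assuming a vLLM build with compressed-tensors support is installed and the directory name produced by the example above:

from vllm import LLM, SamplingParams

# Load the W4A16 AutoRound checkpoint saved by the example above.
llm = LLM(model="Meta-Llama-3-8B-Instruct-W4A16-G128-AutoRound")
outputs = llm.generate(["Hello my name is"], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)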
2 changes: 2 additions & 0 deletions setup.py
@@ -144,6 +144,8 @@ def localversion_func(version: ScmVersion) -> str:
            if BUILD_TYPE == "release"
            else "compressed-tensors>=0.12.3a2"
        ),
        # TODO: replace it with the release version
        ("auto_round @ git+https://github.com/intel/auto-round.git@llmc"),
    ],
    extras_require={
        "dev": [
1 change: 1 addition & 0 deletions src/llmcompressor/modifiers/quantization/__init__.py
@@ -2,3 +2,4 @@

from .gptq import *
from .quantization import *
from .autoround import *
3 changes: 3 additions & 0 deletions src/llmcompressor/modifiers/quantization/autoround/__init__.py
@@ -0,0 +1,3 @@
# ruff: noqa

from .base import *
323 changes: 323 additions & 0 deletions src/llmcompressor/modifiers/quantization/autoround/base.py
@@ -0,0 +1,323 @@
from typing import Dict, List, Optional, Tuple

import torch
from compressed_tensors.quantization import (
    QuantizationConfig,
    QuantizationScheme,
    QuantizationStrategy,
    enable_quantization,
)
from compressed_tensors.utils import (
    align_module_device,
    match_named_modules,
    update_offload_parameter,
)
from loguru import logger
from pydantic import PrivateAttr

from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.quantization.calibration import apply_calibration_status
from llmcompressor.modifiers.quantization.quantization import QuantizationMixin

__all__ = ["AutoRoundModifier"]


def _is_decoding_layer(module, name):
return "decoderlayer" in module.__class__.__name__.lower()


class _LLModelWrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList()

    def forward(self, *args, **kwargs):
        for layer in self.layers:
            res = layer(*args, **kwargs)
        return res


class _PretrainModelWrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = _LLModelWrapper()

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)


def _wrap_decoding_layer(layer: torch.nn.Module) -> _PretrainModelWrapper:
    wrapped_model = _PretrainModelWrapper()
    wrapped_model.model.layers.append(layer)
    first_param = next(layer.parameters())
    wrapped_model.dtype = first_param.dtype
    return wrapped_model


class AutoRoundModifier(Modifier, QuantizationMixin):
"""
Implements the AutoRound algorithm from https://aclanthology.org/2024.findings-emnlp.662.pdf.
This modifier leverages signed gradient descent (SignSGD) optimizer and
block-wise loss to optimize rounding values and weight clipping in a few steps.
| Sample yaml:
| test_stage:
| modifiers:
| AutoRoundModifier:
| iters: 200
| config_groups:
| group_0:
| targets:
| - "Linear"
| input_activations: null
| output_activations: null
| weights:
| num_bits: 4
| type: "int"
| symmetric: true
| strategy: group
| group_size: 128
Lifecycle:
- on_initialize
- apply config to model
- on_start
- add input capture hooks to decoding layers
- on_sequential_epoch_end
- apply_autoround
- post_autoround_cleanup
- on_finalize
- remove_hooks()
- model.apply(freeze_module_quantization)
:param config_groups: dictionary specifying quantization schemes to apply to target
modules. Modules not matching a scheme target will NOT be quantized.
:param targets: list of layer names to quantize if a scheme is provided. Defaults
to Linear layers
:param ignore: optional list of module class names or submodule names to not
quantize even if they match a target in config_groups. Defaults to empty list.
:param scheme: a single quantization scheme to apply to the model. This is a
dictionary that supports all keys from QuantizationScheme except targets, which
will be set to the targets parameter set at the modifier level.
"""

    # AutoRound modifier arguments
    iters: Optional[int] = 200
    enable_torch_compile: Optional[bool] = True

    # private variables
    _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
    _cur_layer_idx = PrivateAttr(default=0)
    _all_module_input: Dict[str, List[Tuple]] = PrivateAttr(default_factory=dict)

    def resolve_quantization_config(self) -> QuantizationConfig:
        config = super().resolve_quantization_config()
        return config

    def _add_temporary_names(self, model: torch.nn.Module):
        for name, mod in model.named_modules():
            mod._tmp_name = name

    def _remove_temporary_names(self, model: torch.nn.Module):
        for _, mod in model.named_modules():
            if hasattr(mod, "_tmp_name"):
                del mod._tmp_name

    def on_initialize(self, state: State, **kwargs) -> bool:
        """
        Initialize and run the AutoRound algorithm on the current state
        :param state: session state storing input model and calibration data
        """
        # apply config to model and prepare calibration hooks
        if QuantizationMixin.has_config(self):
            QuantizationMixin.initialize_quantization(self, state.model)

        # prepare module names
        self._module_names = {
            m: name
            for name, m in match_named_modules(state.model, self.targets, self.ignore)
        }
        self._add_temporary_names(state.model)
        # freeze all model parameters
        for name, param in state.model.named_parameters():
            param.requires_grad_(False)
        return True

    def start_calibration(self, model: torch.nn.Module):
        """
        Register activation calibration hooks and enable quantization as we calibrate
        :param model: model to prepare for calibration
        """

        for _, module in match_named_modules(model, self.targets, self.ignore):
            # Note: no need to register observers for auto-round
            self._calibration_hooks |= self._initialize_hooks(module)
            apply_calibration_status(module)

        model.apply(enable_quantization)  # quantize at the same time as calibrate

    def input_capture_hook(self, module, *args, **kwargs):
        if module._tmp_name not in self._all_module_input:
            self._all_module_input[module._tmp_name] = []
        self._all_module_input[module._tmp_name].append((args, kwargs))

    def on_start(self, state: State, event: Event, **kwargs):
        self.started_ = True

        # register quantization calibration hooks
        # assume quantization has been initialized by this modifier or one before it
        self.start_calibration(state.model)
        for name, module in state.model.named_modules():
            if _is_decoding_layer(module, name):
                # register input capture hook for decoding layers
                self.register_hook(
                    module, self.input_capture_hook, "forward_pre", with_kwargs=True
                )

    def on_event(self, state: State, event: Event, **kwargs):
        if event.type_ == EventType.CALIBRATION_EPOCH_START:
            if not self.started_:
                self.on_start(state, None)

        if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
            self.apply_autoround(state)
            self.post_autoround_cleanup()

        if event.type_ == EventType.CALIBRATION_EPOCH_END:
            if not self.ended_:
                self.on_end(state, None)

    def _mapping_config_to_autoround(self):
        from auto_round.schemes import QuantizationScheme as ARQuantizationScheme

        resolved_config = self.resolved_config
        quant_scheme = None
        # TODO: relax the constraint below in later PRs
        assert len(resolved_config.config_groups) == 1, (
            "AutoRoundModifier only supports one quantization scheme for now, "
            f"got {len(resolved_config.config_groups)}"
        )

        for scheme in resolved_config.config_groups.values():
            assert isinstance(
                scheme, QuantizationScheme
            ), f"Expected QuantizationScheme, got {type(scheme)}"
            quant_scheme = scheme
        weight_args = quant_scheme.weights
        assert weight_args.strategy == QuantizationStrategy.GROUP, (
            "Only group-wise quantization is supported in AutoRoundModifier for now, "
            f"got {weight_args.strategy}"
        )
        assert quant_scheme.input_activations is None, (
            "Input activation quantization is not supported in AutoRoundModifier, "
            f"got {quant_scheme.input_activations}"
        )
        assert quant_scheme.output_activations is None, (
            "Output activation quantization is not supported in AutoRoundModifier, "
            f"got {quant_scheme.output_activations}"
        )
        ar_quant_scheme = ARQuantizationScheme(
            bits=weight_args.num_bits,
            sym=weight_args.symmetric,
            group_size=weight_args.group_size,
            data_type=weight_args.type,
            act_bits=16,
        )
        return ar_quant_scheme

    def apply_autoround(self, state):
        """
        Applies AutoRound quantization tuning on the current decoding layer.

        The tuning logic is as follows:

            for iter in range(iters):
                quant_output = forward(layer, cached_inputs)
                loss = mse_loss(quant_output, original_output)
                loss.backward()
                optimizer.step()
                if loss < best_loss:
                    best_params = save_params(layer)

        This method retrieves the current decoding layer, wraps it for
        compatibility with AutoRound, and performs iterative optimization
        to minimize the quantization error. The best parameters are tracked
        and applied to the layer after tuning.

        For more details, please refer to the AutoRound repository:
        https://github.com/intel/auto-round/
        """
        cur_layer_idx = self._cur_layer_idx
        logger.info("Applying AutoRound to layer index: {}", cur_layer_idx)
        self._cur_layer_idx += 1
        if cur_layer_idx >= len(state.model.model.layers):
            # skip the lm_head layer
            return
        decoding_layer = state.model.model.layers[cur_layer_idx]

        wrapped_model = _wrap_decoding_layer(decoding_layer)

        with torch.enable_grad(), align_module_device(decoding_layer):
            import auto_round

            parsed_scheme = self._mapping_config_to_autoround()
            ar = auto_round.AutoRound(
                model=wrapped_model,
                tokenizer="",
                scheme=parsed_scheme,
                iters=self.iters,
                enable_quanted_input=False,
                enable_torch_compile=self.enable_torch_compile,
            )
            # TODO: configure layer-wise config based on self.resolved_config
            ar.configure_layer_config()
            first_param = next(decoding_layer.parameters())
            device = first_param.device
            input_name = f"model.layers.{cur_layer_idx}"
            cur_inputs = self._all_module_input[input_name]
            decoding_layer.tuning_device = device

            ar.quantize_block(
                block=decoding_layer,
                inputs=cur_inputs,
                normalize_inputs=True,
                device=device,
                # leave offloading to LLM Compressor
                auto_offload=False,
            )
            # Update offloaded parameters and remove temporary attributes
            for _, module in decoding_layer.named_modules():
                if hasattr(module, "weight_scale") and hasattr(
                    module, "weight_zero_point"
                ):
                    # Note: the weight has already been quantize-dequantized
                    # in place by auto-round.
                    weight_scale = module.scale
                    del module.scale
                    # TODO: update zero_point after supporting asymmetric quantization
                    update_offload_parameter(module, "weight_scale", weight_scale)
        decoding_layer.eval()

    def post_autoround_cleanup(self):
        self._all_module_input.clear()

    def on_end(self, state: State, event: Event, **kwargs):
        """
        Finish calibrating by removing observers and calibration hooks
        """
        self.ended_ = True
        QuantizationMixin.end_calibration(self, state.model)
        self._remove_temporary_names(state.model)
        self.remove_hooks()

    def on_finalize(self, state: State, **kwargs) -> bool:
        """
        Disable the quantization observers used by the AutoRound algorithm
        :param state: session state storing input model and calibration data
        """
        if not self.ended_:
            self.on_end(state, None)

        return True
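Reviewer note (not part of the diff): the sample YAML in the AutoRoundModifier docstring has a programmatic equivalent. A minimal sketch, assuming the QuantizationScheme/QuantizationArgs API from compressed-tensors that base.py already builds on:

from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
)

from llmcompressor.modifiers.quantization import AutoRoundModifier

# Same W4A16, group-size-128 weight scheme as the docstring YAML, expressed
# through config_groups instead of the scheme shorthand.
recipe = AutoRoundModifier(
    iters=200,
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(
                num_bits=4,
                type="int",
                symmetric=True,
                strategy=QuantizationStrategy.GROUP,
                group_size=128,
            ),
        )
    },
    ignore=["lm_head"],
)

This recipe can be passed to oneshot() exactly like the string-scheme form used in the example script.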