8 changes: 6 additions & 2 deletions examples/quantization_w4a4_fp4/llama3_example.py
@@ -9,10 +9,13 @@
MODEL_ID = "/data5/yliu7/HF_HOME/meta-llama/Llama-3.2-1B-Instruct/"
# MODEL_ID = "meta-llama/Llama-3.3-70B-Instruct"
scheme_name = "NVFP4"
scheme_name = "MXFP4"

# scheme_name = "MXFP8"
# scheme_name = "FP8"

scheme_name = "NVFPP_B32"
scheme_name = "NVFPP_B16"
# scheme_name = "MXFP4"
# scheme_name = ""
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + f"-{scheme_name}"
SAVE_DIR = f"/data5/yliu7/HF_HOME/{SAVE_DIR}"
print(f"Saving to {SAVE_DIR}")
@@ -85,6 +88,7 @@ def tokenize(sample):
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
model.device
)
print(f"=========== Starting generation =================")
output = model.generate(input_ids, max_new_tokens=10)

print(tokenizer.decode(output[0]))
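For context, the scheme name selected above is what the rest of this example feeds into the quantization recipe. A minimal sketch of that downstream usage, assuming the NVFPP_B32 / NVFPP_B16 presets are registered as named schemes in the compressed-tensors build this PR targets; model, ds, and scheme_name are the variables already defined in the script:

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

# Quantize every Linear layer with the selected preset; keep lm_head in higher precision.
recipe = QuantizationModifier(targets="Linear", scheme=scheme_name, ignore=["lm_head"])

# Calibration pass over the tokenized dataset prepared earlier in the script.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
)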
8 changes: 7 additions & 1 deletion src/llmcompressor/modifiers/quantization/calibration.py
@@ -14,6 +14,7 @@
is_kv_cache_quant_scheme,
is_mx,
is_mxfp4,
use_global_scales,
)
from compressed_tensors.utils import align_module_device, update_parameter_data
from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
@@ -149,7 +150,10 @@ def update_weight_global_scale(module: Module):
return
weight_quant_args = getattr_chain(module, "quantization_scheme.weights")
if is_mx(quantization_args=weight_quant_args):
# MX schemes do not use global scale
return
if not use_global_scales(quantization_args=weight_quant_args):
# this scheme does not use a global scale
return

call_observer(
@@ -209,6 +213,8 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
calculate_qparams = False
if is_fp4(quantization_args=quantization_args):
calculate_gparam = True
if not use_global_scales(quantization_args=quantization_args):
calculate_gparam = False

call_observer(
module=module,
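The calibration changes above add a second gate on top of the existing MX check. A minimal sketch of the combined decision, assuming is_mx and use_global_scales come from the compressed-tensors fork this PR builds against; the helper name below is illustrative, not part of the library:

from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.quantization.utils import is_mx, use_global_scales

def needs_global_scale(quant_args: QuantizationArgs) -> bool:
    """Illustrative helper: True only for schemes that calibrate a global scale."""
    if is_mx(quantization_args=quant_args):
        return False  # MX formats carry per-block scales only
    if not use_global_scales(quantization_args=quant_args):
        return False  # e.g. the NVFPP_B32 / NVFPP_B16 variants enabled here
    return True       # NVFP4-style schemes still calibrate a global scale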
4 changes: 3 additions & 1 deletion src/llmcompressor/modifiers/utils/helpers.py
@@ -2,7 +2,7 @@

import torch
from compressed_tensors.quantization import QuantizationStrategy
from compressed_tensors.quantization.utils import is_fp4, is_mxfp4
from compressed_tensors.quantization.utils import is_fp4, is_mxfp4, use_global_scales
from compressed_tensors.utils import align_modules, update_parameter_data
from torch.nn import Linear, Module

@@ -52,6 +52,8 @@ def _valid_tensor_group_quant(layer_list: List[Linear]):
return False
if not is_fp4(quantization_args=weight_quant_args):
return False
if not use_global_scales(quantization_args=weight_quant_args):
return False
return True

if _is_attention_module(submodule):
@@ -9,6 +9,7 @@
)
from compressed_tensors.quantization.utils import is_module_quantized
from loguru import logger
from compressed_tensors.quantization.utils.helpers import is_nvfpp_b32, is_nvfpp_b16

__all__ = ["infer_and_set_per_module_quantization_format"]

@@ -26,6 +27,10 @@ def _get_quant_compression_format(
if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value:
if weight_args.is_mx:
return CompressionFormat.mxfp4_pack_quantized
if is_nvfpp_b32(weight_args):
return CompressionFormat.nvfpp_b32_pack_quantized
if is_nvfpp_b16(weight_args):
return CompressionFormat.nvfpp_b16_pack_quantized
return CompressionFormat.nvfp4_pack_quantized

if is_weight_only: # w4a16 and w8a16
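To make the new precedence easy to review, here is a minimal sketch of the 4-bit float branch above, assuming the is_mx flag, the is_nvfpp_b32 / is_nvfpp_b16 helpers, and the mxfp4 / nvfpp CompressionFormat members exist in the compressed-tensors fork this PR targets: MXFP4 wins first, then NVFPP block-32, then NVFPP block-16, with plain NVFP4 as the fallback.

from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.quantization.utils.helpers import is_nvfpp_b16, is_nvfpp_b32

def fp4_compression_format(weight_args: QuantizationArgs) -> CompressionFormat:
    # Order matters: MX is checked before the NVFPP variants, and NVFP4 is the default.
    if weight_args.is_mx:
        return CompressionFormat.mxfp4_pack_quantized
    if is_nvfpp_b32(weight_args):
        return CompressionFormat.nvfpp_b32_pack_quantized
    if is_nvfpp_b16(weight_args):
        return CompressionFormat.nvfpp_b16_pack_quantized
    return CompressionFormat.nvfp4_pack_quantized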
12 changes: 4 additions & 8 deletions tests/llmcompressor/transformers/kv_cache/test_kv_cache.py
@@ -231,14 +231,10 @@ def test_kv_cache_gptq_model_state_dict_attr(kv_cache_fixture, tmp_path):

output_dir, _ = next(kv_cache_fixture(recipe, tmp_path))

with init_empty_weights():
# TODO: There is a bug in `apply_quantization_config` which means that, if using
# CompressedLinears, the compression status is inferred to `compressed` and
# therefore the attention kvcache parameters never undergo initializations
model = AutoModelForCausalLM.from_pretrained(
output_dir,
quantization_config=CompressedTensorsConfig(run_compressed=False),
)
model = AutoModelForCausalLM.from_pretrained(
output_dir,
quantization_config=CompressedTensorsConfig(run_compressed=False),
)

counts = 0
for name, submodule in model.named_modules():