From 7b67045b9d5f2f32b28b629bad440ec91616b74c Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 25 Jun 2025 22:30:46 -0400 Subject: [PATCH 01/10] quant llama 70b Signed-off-by: yiliu30 --- examples/quantization_w4a4_fp4/llama3_example.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/quantization_w4a4_fp4/llama3_example.py b/examples/quantization_w4a4_fp4/llama3_example.py index 95d01657b9..21fa02046e 100644 --- a/examples/quantization_w4a4_fp4/llama3_example.py +++ b/examples/quantization_w4a4_fp4/llama3_example.py @@ -6,6 +6,11 @@ from llmcompressor.utils import dispatch_for_generation MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" +MODEL_ID = "/data5/yliu7/HF_HOME/meta-llama/Llama-3.2-1B-Instruct/" +MODEL_ID = "meta-llama/Llama-3.3-70B-Instruct" +SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4" +SAVE_DIT = f"/data5/yliu7/HF_HOME/{SAVE_DIR}" +print(f"Saving to {SAVE_DIT}") # Load model. model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") @@ -76,6 +81,6 @@ def tokenize(sample): # Save to disk in compressed-tensors format. -SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4" + model.save_pretrained(SAVE_DIR, save_compressed=True) tokenizer.save_pretrained(SAVE_DIR) From 23087ce304ce029fe8728e5e0e1673ca5eec8c38 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 25 Jun 2025 22:32:15 -0400 Subject: [PATCH 02/10] add deepseek w4a4 nvfp4 Signed-off-by: yiliu30 --- .../quantizing_moe/deepseek_moe_w4a4_nvfp4.py | 135 ++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 examples/quantizing_moe/deepseek_moe_w4a4_nvfp4.py diff --git a/examples/quantizing_moe/deepseek_moe_w4a4_nvfp4.py b/examples/quantizing_moe/deepseek_moe_w4a4_nvfp4.py new file mode 100644 index 0000000000..630d58ddc9 --- /dev/null +++ b/examples/quantizing_moe/deepseek_moe_w4a4_nvfp4.py @@ -0,0 +1,135 @@ +import torch +from datasets import load_dataset +from packaging.version import Version +from transformers import AutoModelForCausalLM, AutoTokenizer, __version__ + +from llmcompressor import oneshot +from llmcompressor.utils import dispatch_for_generation + +# NOTE: transformers 4.49.0 has an attribute error with DeepSeek. +# Please consider either downgrading your transformers version to a +# previous version or upgrading to a version where this bug is fixed + +# select a Mixture of Experts model for quantization +MODEL_ID = "deepseek-ai/DeepSeek-V2.5" +MODEL_ID = "/data0/deepseek-ai/DeepSeek-V2-Lite" +MODEL_ID = "/data0/deepseek-ai/DeepSeek-R1" +MODEL_ID = "/data1/DeepSeek-R1-bf16" + +model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True, + device_map="auto" +) +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + +# Select calibration dataset. +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" +NUM_CALIBRATION_SAMPLES = 2 +MAX_SEQUENCE_LENGTH = 2048 + + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") +ds = ds.shuffle(seed=42) + + +def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + + +ds = ds.map(preprocess) + + +# Tokenize inputs. 
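# The two dataset.map passes here first render each conversation into a single
# prompt string (preprocess, above) and then tokenize it (the helper below).
# A minimal standalone sketch of that flow on one hand-written sample; it uses
# only the standard Hugging Face tokenizer API, and the prompt text is an
# illustrative placeholder rather than a row from the calibration set.
from transformers import AutoTokenizer

demo_tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V2.5")
demo_sample = {"messages": [{"role": "user", "content": "Summarize NVFP4 in one sentence."}]}
demo_text = demo_tokenizer.apply_chat_template(demo_sample["messages"], tokenize=False)
demo_ids = demo_tokenizer(
    demo_text,
    max_length=MAX_SEQUENCE_LENGTH,  # 2048, as set earlier in this example
    truncation=True,
    add_special_tokens=False,
)
# demo_text is what preprocess() stores in ds["text"]; demo_ids mirrors what the
# tokenize() helper below produces for the oneshot calibration pass.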
+def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# define a llmcompressor recipe for W416 quantization +# since the MoE gate layers are sensitive to quantization, we add them to the ignore +# list so they remain at full precision +# recipe = "deepseek_recipe_w4a16.yaml" + +from llmcompressor.modifiers.quantization import QuantizationModifier +from llmcompressor.utils import dispatch_for_generation + +recipe = QuantizationModifier(targets="Linear", scheme="NVFP4", ignore=["lm_head"] +) + +oneshot( + model=model, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + save_compressed=True, + trust_remote_code_model=True, +) + +# Confirm generations of the quantized model look sane. +# Generation is broken for deepseek models when using the latest transformers package +if Version(__version__) < Version("4.48"): + print("========== SAMPLE GENERATION ==============") + dispatch_for_generation(model) + input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") + output = model.generate(input_ids, max_new_tokens=20) + print(tokenizer.decode(output[0])) + print("==========================================") +else: + print( + "WARNING: cannot perform sample generation of " + "deepseek models with transformers >= 4.48" + ) + +# Save to disk in compressed-tensors format. +SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) + + +# # Run the model on vLLM +# try: +# from vllm import LLM, SamplingParams + +# vllm_installed = True +# except ImportError: +# vllm_installed = False + +# if vllm_installed: +# print("vLLM installed, running using vLLM") +# sampling_params = SamplingParams(temperature=0.80, top_p=0.95) +# llm = LLM( +# model=SAVE_DIR, +# tensor_parallel_size=2, +# trust_remote_code=True, +# max_model_len=1042, +# dtype=torch.half, +# ) +# prompts = [ +# "The capital of France is", +# "The president of the US is", +# "My name is", +# ] + +# outputs = llm.generate(prompts, sampling_params) +# print("================= vLLM GENERATION ======================") +# for output in outputs: +# assert output +# prompt = output.prompt +# generated_text = output.outputs[0].text +# print("PROMPT", prompt) +# print("GENERATED TEXT", generated_text) From 55dafc4f284fac2a93d12c2021fedbc67124a6a2 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 27 Jun 2025 01:40:42 -0400 Subject: [PATCH 03/10] add mxfp4 support Signed-off-by: yiliu30 --- examples/quantization_w4a4_fp4/llama3_example.py | 11 ++++++++--- .../modifiers/quantization/calibration.py | 10 +++++++--- src/llmcompressor/modifiers/utils/helpers.py | 3 +++ 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/examples/quantization_w4a4_fp4/llama3_example.py b/examples/quantization_w4a4_fp4/llama3_example.py index 21fa02046e..08cdd2939f 100644 --- a/examples/quantization_w4a4_fp4/llama3_example.py +++ b/examples/quantization_w4a4_fp4/llama3_example.py @@ -8,11 +8,15 @@ MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" MODEL_ID = "/data5/yliu7/HF_HOME/meta-llama/Llama-3.2-1B-Instruct/" MODEL_ID = "meta-llama/Llama-3.3-70B-Instruct" -SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4" +scheme_name = "NVFP4" +scheme_name = "MXFP4" + +SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + 
f"-{scheme_name}" SAVE_DIT = f"/data5/yliu7/HF_HOME/{SAVE_DIR}" print(f"Saving to {SAVE_DIT}") # Load model. + model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) @@ -22,7 +26,7 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 20 +NUM_CALIBRATION_SAMPLES = 4 MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. @@ -60,7 +64,8 @@ def tokenize(sample): # * quantize the weights to fp4 with per group 16 via ptq # * calibrate a global_scale for activations, which will be used to # quantize activations to fp4 on the fly -recipe = QuantizationModifier(targets="Linear", scheme="NVFP4", ignore=["lm_head"]) + +recipe = QuantizationModifier(targets="Linear", scheme=scheme_name, ignore=["lm_head"]) # Apply quantization. oneshot( diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index 63e1c2a24a..97182fb116 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -9,11 +9,11 @@ QuantizationStrategy, ) from compressed_tensors.quantization.lifecycle.forward import forward_quantize -from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme +from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme, is_mxfp4 from compressed_tensors.utils import align_module_device, update_parameter_data from loguru import logger from torch.nn import Module - +from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme, is_mxfp4 from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache from llmcompressor.observers import Observer from llmcompressor.utils.helpers import getattr_chain @@ -122,6 +122,10 @@ def update_weight_global_scale(module: Module): != QuantizationStrategy.TENSOR_GROUP ): return + weight_quant_args = getattr_chain(module, "quantization_scheme.weights") + if is_mxfp4(quantization_args=weight_quant_args): + # mxfp4 does not use global scale + return call_observer( module, @@ -178,7 +182,7 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str): if quantization_args is not None: if quantization_args.dynamic in (True, DynamicType.LOCAL): calculate_qparams = False - if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP: + if is_fp4(quantization_args=quantization_args): calculate_gparam = True call_observer( diff --git a/src/llmcompressor/modifiers/utils/helpers.py b/src/llmcompressor/modifiers/utils/helpers.py index a10af31567..ce62cecb7d 100644 --- a/src/llmcompressor/modifiers/utils/helpers.py +++ b/src/llmcompressor/modifiers/utils/helpers.py @@ -4,6 +4,7 @@ from compressed_tensors.quantization import QuantizationStrategy from compressed_tensors.utils import align_modules, update_parameter_data from torch.nn import Linear, Module +from compressed_tensors.quantization.utils import is_fp4, is_mxfp4 __all__ = ["update_fused_layer_weight_global_scales"] @@ -49,6 +50,8 @@ def _valid_tensor_group_quant(layer_list: List[Linear]): if weight_quant_args.strategy != QuantizationStrategy.TENSOR_GROUP: return False + if not is_fp4(quantization_args=weight_quant_args): + return False return True if _is_attention_module(submodule): From 12ab1b16a5fa891b6da4740f3539838339ca1b4b Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 27 Jun 2025 04:42:25 -0400 Subject: [PATCH 04/10] 
fix typo Signed-off-by: yiliu30 --- examples/quantization_w4a4_fp4/llama3_example.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/quantization_w4a4_fp4/llama3_example.py b/examples/quantization_w4a4_fp4/llama3_example.py index 08cdd2939f..0f0230eaf6 100644 --- a/examples/quantization_w4a4_fp4/llama3_example.py +++ b/examples/quantization_w4a4_fp4/llama3_example.py @@ -12,8 +12,8 @@ scheme_name = "MXFP4" SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + f"-{scheme_name}" -SAVE_DIT = f"/data5/yliu7/HF_HOME/{SAVE_DIR}" -print(f"Saving to {SAVE_DIT}") +SAVE_DIR = f"/data5/yliu7/HF_HOME/{SAVE_DIR}" +print(f"Saving to {SAVE_DIR}") # Load model. From a278a815730bbdce0f0f9779cf7cc603bb0a7949 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 27 Jun 2025 06:42:15 -0400 Subject: [PATCH 05/10] format Signed-off-by: yiliu30 --- examples/quantizing_moe/deepseek_moe_w4a4_nvfp4.py | 6 ++---- src/llmcompressor/modifiers/quantization/calibration.py | 8 ++++++-- src/llmcompressor/modifiers/utils/helpers.py | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/quantizing_moe/deepseek_moe_w4a4_nvfp4.py b/examples/quantizing_moe/deepseek_moe_w4a4_nvfp4.py index 630d58ddc9..d64f3a9772 100644 --- a/examples/quantizing_moe/deepseek_moe_w4a4_nvfp4.py +++ b/examples/quantizing_moe/deepseek_moe_w4a4_nvfp4.py @@ -17,8 +17,7 @@ MODEL_ID = "/data1/DeepSeek-R1-bf16" model = AutoModelForCausalLM.from_pretrained( - MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True, - device_map="auto" + MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto" ) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) @@ -67,8 +66,7 @@ def tokenize(sample): from llmcompressor.modifiers.quantization import QuantizationModifier from llmcompressor.utils import dispatch_for_generation -recipe = QuantizationModifier(targets="Linear", scheme="NVFP4", ignore=["lm_head"] -) +recipe = QuantizationModifier(targets="Linear", scheme="NVFP4", ignore=["lm_head"]) oneshot( model=model, diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index 97182fb116..c404ceb4d2 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -9,11 +9,15 @@ QuantizationStrategy, ) from compressed_tensors.quantization.lifecycle.forward import forward_quantize -from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme, is_mxfp4 +from compressed_tensors.quantization.utils import ( + is_fp4, + is_kv_cache_quant_scheme, + is_mxfp4, +) from compressed_tensors.utils import align_module_device, update_parameter_data from loguru import logger from torch.nn import Module -from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme, is_mxfp4 + from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache from llmcompressor.observers import Observer from llmcompressor.utils.helpers import getattr_chain diff --git a/src/llmcompressor/modifiers/utils/helpers.py b/src/llmcompressor/modifiers/utils/helpers.py index ce62cecb7d..e489a1ed12 100644 --- a/src/llmcompressor/modifiers/utils/helpers.py +++ b/src/llmcompressor/modifiers/utils/helpers.py @@ -2,9 +2,9 @@ import torch from compressed_tensors.quantization import QuantizationStrategy +from compressed_tensors.quantization.utils import is_fp4, is_mxfp4 from compressed_tensors.utils import align_modules, update_parameter_data from torch.nn import 
Linear, Module -from compressed_tensors.quantization.utils import is_fp4, is_mxfp4 __all__ = ["update_fused_layer_weight_global_scales"] From a268d2ee26554ca684e7d4d6b89f262e146a36fc Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 27 Jun 2025 06:41:26 -0400 Subject: [PATCH 06/10] mxfp8 support Signed-off-by: yiliu30 --- examples/quantization_w4a4_fp4/llama3_example.py | 6 ++++-- src/llmcompressor/modifiers/quantization/calibration.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/quantization_w4a4_fp4/llama3_example.py b/examples/quantization_w4a4_fp4/llama3_example.py index 0f0230eaf6..92bc5dd8f8 100644 --- a/examples/quantization_w4a4_fp4/llama3_example.py +++ b/examples/quantization_w4a4_fp4/llama3_example.py @@ -7,9 +7,11 @@ MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" MODEL_ID = "/data5/yliu7/HF_HOME/meta-llama/Llama-3.2-1B-Instruct/" -MODEL_ID = "meta-llama/Llama-3.3-70B-Instruct" +# MODEL_ID = "meta-llama/Llama-3.3-70B-Instruct" scheme_name = "NVFP4" scheme_name = "MXFP4" +scheme_name = "MXFP8" +# scheme_name = "FP8" SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + f"-{scheme_name}" SAVE_DIR = f"/data5/yliu7/HF_HOME/{SAVE_DIR}" @@ -80,7 +82,7 @@ def tokenize(sample): print("========== SAMPLE GENERATION ==============") dispatch_for_generation(model) input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") -output = model.generate(input_ids, max_new_tokens=100) +output = model.generate(input_ids, max_new_tokens=20) print(tokenizer.decode(output[0])) print("==========================================\n\n") diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index d6f8f0a018..a3e154985b 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -17,7 +17,7 @@ from compressed_tensors.utils import align_module_device, update_parameter_data from loguru import logger from torch.nn import Module - +from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme, is_mxfp4, is_mx from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache from llmcompressor.observers import Observer from llmcompressor.utils.helpers import getattr_chain @@ -142,7 +142,7 @@ def update_weight_global_scale(module: Module): ): return weight_quant_args = getattr_chain(module, "quantization_scheme.weights") - if is_mxfp4(quantization_args=weight_quant_args): + if is_mx(quantization_args=weight_quant_args): # mxfp4 does not use global scale return From 7c54f6a14cacfbe9d2ce032918c4bf39e894aa26 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 27 Jun 2025 06:45:50 -0400 Subject: [PATCH 07/10] format code Signed-off-by: yiliu30 --- src/llmcompressor/modifiers/quantization/calibration.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index a3e154985b..6c19d03284 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -12,12 +12,13 @@ from compressed_tensors.quantization.utils import ( is_fp4, is_kv_cache_quant_scheme, + is_mx, is_mxfp4, ) from compressed_tensors.utils import align_module_device, update_parameter_data from loguru import logger from torch.nn import Module -from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme, 
is_mxfp4, is_mx + from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache from llmcompressor.observers import Observer from llmcompressor.utils.helpers import getattr_chain From e7452e26609568d179f06855aa40a5a9a9ffa4ef Mon Sep 17 00:00:00 2001 From: Yi Liu Date: Fri, 27 Jun 2025 18:57:19 +0800 Subject: [PATCH 08/10] Update src/llmcompressor/modifiers/quantization/calibration.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/llmcompressor/modifiers/quantization/calibration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index 6c19d03284..bb66b5c71e 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -144,7 +144,7 @@ def update_weight_global_scale(module: Module): return weight_quant_args = getattr_chain(module, "quantization_scheme.weights") if is_mx(quantization_args=weight_quant_args): - # mxfp4 does not use global scale +# MX schemes do not use global scale return call_observer( From f0565585a8daabb645c98a56c25b422972e1352b Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 2 Jul 2025 09:24:28 -0400 Subject: [PATCH 09/10] fix quant format Signed-off-by: yiliu30 --- examples/quantization_w4a4_fp4/llama3_example.py | 2 +- .../transformers/compression/quantization_format.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/quantization_w4a4_fp4/llama3_example.py b/examples/quantization_w4a4_fp4/llama3_example.py index 92bc5dd8f8..8bf5bbf11a 100644 --- a/examples/quantization_w4a4_fp4/llama3_example.py +++ b/examples/quantization_w4a4_fp4/llama3_example.py @@ -10,7 +10,7 @@ # MODEL_ID = "meta-llama/Llama-3.3-70B-Instruct" scheme_name = "NVFP4" scheme_name = "MXFP4" -scheme_name = "MXFP8" +# scheme_name = "MXFP8" # scheme_name = "FP8" SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + f"-{scheme_name}" diff --git a/src/llmcompressor/transformers/compression/quantization_format.py b/src/llmcompressor/transformers/compression/quantization_format.py index 59d29bae48..598744e448 100644 --- a/src/llmcompressor/transformers/compression/quantization_format.py +++ b/src/llmcompressor/transformers/compression/quantization_format.py @@ -60,12 +60,14 @@ def infer_quantization_format( SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR ) is_weight_only = len(input_args) == 0 and len(weight_args) > 0 - if ( weight_args[0].num_bits == 4 and weight_args[0].type == QuantizationType.FLOAT.value ): - return CompressionFormat.nvfp4_pack_quantized + if weight_args[0].is_mx: + return CompressionFormat.mxfp4_pack_quantized + else: + return CompressionFormat.nvfp4_pack_quantized if is_weight_only: # w4a16 and w8a16 is_valid_pack = all( From cf9843b582ec69ccd32d77ada641483ecfdc8b60 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Wed, 3 Sep 2025 05:47:17 -0400 Subject: [PATCH 10/10] tmp fix Signed-off-by: yiliu30 --- .../quantization_w4a4_fp4/llama3_example.py | 2 +- .../compression/quantization_format.py | 114 +++++++++++------- 2 files changed, 71 insertions(+), 45 deletions(-) diff --git a/examples/quantization_w4a4_fp4/llama3_example.py b/examples/quantization_w4a4_fp4/llama3_example.py index b1cb51e4e3..96ab5d49d9 100644 --- a/examples/quantization_w4a4_fp4/llama3_example.py +++ b/examples/quantization_w4a4_fp4/llama3_example.py @@ -85,7 +85,7 @@ def tokenize(sample): 
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to( model.device ) -output = model.generate(input_ids, max_new_tokens=100) +output = model.generate(input_ids, max_new_tokens=10) print(tokenizer.decode(output[0])) print("==========================================\n\n") diff --git a/src/llmcompressor/transformers/compression/quantization_format.py b/src/llmcompressor/transformers/compression/quantization_format.py index f635c4aac3..650d7c4de1 100644 --- a/src/llmcompressor/transformers/compression/quantization_format.py +++ b/src/llmcompressor/transformers/compression/quantization_format.py @@ -24,6 +24,8 @@ def _get_quant_compression_format( is_weight_only = weight_args is not None and input_args is None if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value: + if weight_args.is_mx: + return CompressionFormat.mxfp4_pack_quantized return CompressionFormat.nvfp4_pack_quantized if is_weight_only: # w4a16 and w8a16 @@ -55,6 +57,30 @@ def _get_quant_compression_format( return CompressionFormat.naive_quantized +def _get_unique_quant_args(model): + """ + Gets a list of all the unique quantization settings present in model + """ + from compressed_tensors.quantization.utils import ( + is_model_quantized, + is_module_quantized, + iter_named_leaf_modules, + ) + quant_info_weight = [] + quant_info_inputs = [] + for _, submodule in iter_named_leaf_modules(model): + if is_module_quantized(submodule): + weight_scheme = submodule.quantization_scheme.weights + input_scheme = submodule.quantization_scheme.input_activations + if weight_scheme is not None: + if weight_scheme not in quant_info_weight: + quant_info_weight.append(weight_scheme) + if input_scheme is not None: + if input_scheme not in quant_info_inputs: + quant_info_inputs.append(input_scheme) + + return quant_info_weight, quant_info_inputs + def infer_and_set_per_module_quantization_format( model, quantization_format: Optional[str] = None, @@ -79,50 +105,50 @@ def infer_and_set_per_module_quantization_format( if not save_compressed: return None - if save_compressed: - weight_args, input_args = _get_unique_quant_args(model) - is_24_structure = ( - SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR - ) - is_weight_only = len(input_args) == 0 and len(weight_args) > 0 - if ( - weight_args[0].num_bits == 4 - and weight_args[0].type == QuantizationType.FLOAT.value - ): - if weight_args[0].is_mx: - return CompressionFormat.mxfp4_pack_quantized - else: - return CompressionFormat.nvfp4_pack_quantized - - if is_weight_only: # w4a16 and w8a16 - is_valid_pack = all( - weight_arg.num_bits in [4, 8] - and weight_arg.type == QuantizationType.INT.value - for weight_arg in weight_args - ) - if not is_valid_pack: # packing only valid for int4 and int 8 - return CompressionFormat.naive_quantized - if is_24_structure: - for arg in weight_args: - if ( - arg.strategy is not QuantizationStrategy.CHANNEL.value - and arg.strategy is not QuantizationStrategy.GROUP.value - ): - # marlin24 kernel only applicable for channel/group quantization - return CompressionFormat.pack_quantized - return CompressionFormat.marlin_24 - return CompressionFormat.pack_quantized - else: # w8a8 float and int - if len(weight_args) == 1: - if ( - weight_args[0].type == QuantizationType.FLOAT.value - and weight_args[0].num_bits == 8 - ): - return CompressionFormat.float_quantized - if weight_args[0].type == QuantizationType.INT.value: - return CompressionFormat.int_quantized - - return CompressionFormat.naive_quantized + # if 
save_compressed: + # weight_args, input_args = _get_unique_quant_args(model) + # is_24_structure = ( + # SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR + # ) + # is_weight_only = len(input_args) == 0 and len(weight_args) > 0 + # if ( + # weight_args[0].num_bits == 4 + # and weight_args[0].type == QuantizationType.FLOAT.value + # ): + # if weight_args[0].is_mx: + # return CompressionFormat.mxfp4_pack_quantized + # else: + # return CompressionFormat.nvfp4_pack_quantized + + # if is_weight_only: # w4a16 and w8a16 + # is_valid_pack = all( + # weight_arg.num_bits in [4, 8] + # and weight_arg.type == QuantizationType.INT.value + # for weight_arg in weight_args + # ) + # if not is_valid_pack: # packing only valid for int4 and int 8 + # return CompressionFormat.naive_quantized + # if is_24_structure: + # for arg in weight_args: + # if ( + # arg.strategy is not QuantizationStrategy.CHANNEL.value + # and arg.strategy is not QuantizationStrategy.GROUP.value + # ): + # # marlin24 kernel only applicable for channel/group quantization + # return CompressionFormat.pack_quantized + # return CompressionFormat.marlin_24 + # return CompressionFormat.pack_quantized + # else: # w8a8 float and int + # if len(weight_args) == 1: + # if ( + # weight_args[0].type == QuantizationType.FLOAT.value + # and weight_args[0].num_bits == 8 + # ): + # return CompressionFormat.float_quantized + # if weight_args[0].type == QuantizationType.INT.value: + # return CompressionFormat.int_quantized + + # return CompressionFormat.naive_quantized if quantization_format:
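# Condensed restatement of the rule that the commented-out block above (and
# _get_quant_compression_format earlier in this file) applies to 4-bit float
# weights: MX-style block scales map to a packed MXFP4 format, while NVFP4,
# which keeps a per-tensor global scale, maps to the packed NVFP4 format.
# Plain strings stand in for the CompressionFormat members so this sketch runs
# without a compressed-tensors build that defines is_mx / MXFP4 support.
def _fp4_format_sketch(num_bits: int, qtype: str, is_mx: bool) -> str:
    if num_bits == 4 and qtype == "float":
        return "mxfp4-pack-quantized" if is_mx else "nvfp4-pack-quantized"
    return "naive-quantized"  # the real helper also covers int4/int8 packing, marlin-24, etc.

assert _fp4_format_sketch(4, "float", is_mx=True) == "mxfp4-pack-quantized"
assert _fp4_format_sketch(4, "float", is_mx=False) == "nvfp4-pack-quantized"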