diff --git a/docs/guides/compression_formats.md b/docs/guides/compression_formats.md
index 5ef77e01e..80aae0025 100644
--- a/docs/guides/compression_formats.md
+++ b/docs/guides/compression_formats.md
@@ -14,6 +14,7 @@ sparsity type. For more details on the quantization schemes, see
 | W4A16 - float | None | nvfp4_pack_quantized | Dense |
 | W4A4 - float | None | nvfp4_pack_quantized | Dense |
 | W4A16 - int | None | pack_quantized | Dense |
+| W4A8 - int | None | int4_quantized | Dense |
 | W8A16 - int | None | pack_quantized | Dense |
 | W8A16 - float | None | naive_quantized | Dense |
 | W8A8 - int | 2:4 | int_quantized | Sparse24 |
diff --git a/examples/quantizing_moe/qwen_example_w4a8.py b/examples/quantizing_moe/qwen_example_w4a8.py
new file mode 100644
index 000000000..dfa1e724d
--- /dev/null
+++ b/examples/quantizing_moe/qwen_example_w4a8.py
@@ -0,0 +1,119 @@
+import torch
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor import oneshot
+from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
+from llmcompressor.utils import dispatch_for_generation
+
+# Select a Mixture of Experts model for quantization.
+MODEL_ID = "Qwen/Qwen1.5-MoE-A2.7B-Chat"
+
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True
+)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+
+# Select calibration dataset.
+DATASET_ID = "HuggingFaceH4/ultrachat_200k"
+DATASET_SPLIT = "train_sft"
+NUM_CALIBRATION_SAMPLES = 512
+MAX_SEQUENCE_LENGTH = 2048
+
+
+# Load dataset and preprocess.
+ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
+ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
+
+
+def preprocess(example):
+    return {
+        "text": tokenizer.apply_chat_template(
+            example["messages"],
+            tokenize=False,
+        )
+    }
+
+
+ds = ds.map(preprocess)
+
+
+# Tokenize inputs.
+def tokenize(sample):
+    return tokenizer(
+        sample["text"],
+        padding=False,
+        max_length=MAX_SEQUENCE_LENGTH,
+        truncation=True,
+        add_special_tokens=False,
+    )
+
+
+ds = ds.map(tokenize, remove_columns=ds.column_names)
+
+# define a llmcompressor recipe for W4A8 quantization: int4 weights (group size 128)
+# with dynamic fp8 activations for the MoE experts, and fp8 for the attention layers;
+# the quantization-sensitive MoE gate layers are ignored and kept at full precision
+
+recipe = """
+quant_stage:
+    quant_modifiers:
+        QuantizationModifier:
+            ignore: ["lm_head", "re:.*mlp.gate$"]
+            config_groups:
+                group_0:
+                    weights:
+                        num_bits: 8
+                        type: float
+                        strategy: block
+                        dynamic: false
+                        symmetric: true
+                        block_structure: [128, 128]
+                    input_activations:
+                        num_bits: 8
+                        type: float
+                        strategy: token
+                        dynamic: true
+                        symmetric: true
+                    targets: ["re:.*self_attn.q_proj", "re:.*self_attn.k_proj",
+                        "re:.*self_attn.v_proj", "re:.*self_attn.o_proj"]
+                group_1:
+                    weights:
+                        num_bits: 4
+                        type: int
+                        strategy: tensor_group
+                        dynamic: false
+                        symmetric: true
+                        group_size: 128
+                    input_activations:
+                        num_bits: 8
+                        type: float
+                        strategy: tensor
+                        dynamic: true
+                        symmetric: true
+                    targets: ["re:.*gate_proj", "re:.*up_proj", "re:.*down_proj"]
+"""
+
+oneshot(
+    model=model,
+    dataset=ds,
+    recipe=recipe,
+    max_seq_length=MAX_SEQUENCE_LENGTH,
+    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+    save_compressed=True,
+    trust_remote_code_model=True,
+)
+
+# Confirm generations of the quantized model look sane.
+print("========== SAMPLE GENERATION ==============") +dispatch_for_generation(model) +sample = tokenizer("Hello my name is", return_tensors="pt") +sample = {key: value.to(model.device) for key, value in sample.items()} +output = model.generate(**sample, max_new_tokens=100) +print(tokenizer.decode(output[0])) +print("==========================================") + +# Save to disk in compressed-tensors format. +SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-wInt4aFp8" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/transformers/compression/quantization_format.py b/src/llmcompressor/transformers/compression/quantization_format.py index f583533af..c9f184e2a 100644 --- a/src/llmcompressor/transformers/compression/quantization_format.py +++ b/src/llmcompressor/transformers/compression/quantization_format.py @@ -49,6 +49,8 @@ def _get_quant_compression_format( and weight_args.num_bits == 8 ): return CompressionFormat.float_quantized + if weight_args.type == QuantizationType.INT.value and weight_args.num_bits == 4 and weight_args.strategy is QuantizationStrategy.TENSOR_GROUP.value: + return CompressionFormat.int4_quantized if weight_args.type == QuantizationType.INT.value: return CompressionFormat.int_quantized