1 change: 1 addition & 0 deletions docs/guides/compression_formats.md
@@ -14,6 +14,7 @@ sparsity type. For more details on the quantization schemes, see
| W4A16 - float | None | nvfp4_pack_quantized | Dense |
| W4A4 - float | None | nvfp4_pack_quantized | Dense |
| W4A16 - int | None | pack_quantized | Dense |
| W4A8 - int | None | int4_quantized | Dense |
| W8A16 - int | None | pack_quantized | Dense |
| W8A16 - float | None | naive_quantized | Dense |
| W8A8 - int | 2:4 | int_quantized | Sparse24 |
119 changes: 119 additions & 0 deletions examples/quantizing_moe/qwen_example_w4a8.py
@@ -0,0 +1,119 @@
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.utils import dispatch_for_generation

# select a Mixture of Experts model for quantization
MODEL_ID = "Qwen/Qwen1.5-MoE-A2.7B-Chat"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048


# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# define a llmcompressor recipe for W4A8 quantization: fp8 block quantization for the
# attention projections, and int4 weights (group size 128) with dynamic fp8 activations
# for the MLP projections. Since the MoE gate layers are sensitive to quantization, we
# add them to the ignore list so they remain at full precision

recipe = """
quant_stage:
    quant_modifiers:
        QuantizationModifier:
            ignore: ["lm_head", "re:.*mlp.gate$"]
            config_groups:
                group_0:
                    weights:
                        num_bits: 8
                        type: float
                        strategy: block
                        dynamic: false
                        symmetric: true
                        block_structure: [128, 128]
                    input_activations:
                        num_bits: 8
                        type: float
                        strategy: token
                        dynamic: true
                        symmetric: true
                    targets: ["re:.*self_attn.q_proj", "re:.*self_attn.k_proj",
                              "re:.*self_attn.v_proj", "re:.*self_attn.o_proj"]
                group_1:
                    weights:
                        num_bits: 4
                        type: int
                        strategy: tensor_group
                        dynamic: false
                        symmetric: true
                        group_size: 128
                    input_activations:
                        num_bits: 8
                        type: float
                        strategy: tensor
                        dynamic: true
                        symmetric: true
                    targets: ["re:.*gate_proj", "re:.*up_proj", "re:.*down_proj"]
"""

oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    save_compressed=True,
    trust_remote_code_model=True,
)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================")

# Save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-wInt4aFp8"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
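
For context (not part of this PR), the saved directory is a compressed-tensors checkpoint, so it can typically be loaded for inference directly; below is a minimal sketch, assuming the installed vLLM build supports the W4A8 compressed-tensors format produced above:

```python
# Hypothetical usage sketch: serve the checkpoint written by save_pretrained above.
# Whether this runs depends on the vLLM version's support for this scheme.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen1.5-MoE-A2.7B-Chat-wInt4aFp8", trust_remote_code=True)
outputs = llm.generate(["Hello my name is"], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```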
@@ -49,6 +49,8 @@ def _get_quant_compression_format(
        and weight_args.num_bits == 8
    ):
        return CompressionFormat.float_quantized
    if weight_args.type == QuantizationType.INT.value and weight_args.num_bits == 4 and weight_args.strategy is QuantizationStrategy.TENSOR_GROUP.value:
Contributor

Severity: medium

For comparing string values, it's safer and more conventional to use the `==` operator instead of `is`. The `is` operator checks for object identity, which might work for interned strings but is not guaranteed across different Python implementations or versions. Using `==` ensures the comparison is always done by value.[1]

Suggested change
if weight_args.type == QuantizationType.INT.value and weight_args.num_bits == 4 and weight_args.strategy is QuantizationStrategy.TENSOR_GROUP.value:
if weight_args.type == QuantizationType.INT.value and weight_args.num_bits == 4 and weight_args.strategy == QuantizationStrategy.TENSOR_GROUP.value:


Footnotes

1. Use `==` for value equality and `is` for identity equality. For comparing string literals, `==` is preferred for robustness, as string interning is an implementation detail.
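
For illustration (not part of the review thread), a short sketch of the difference; the exact result of `is` depends on string interning, which is an implementation detail:

```python
# value vs. identity comparison for strings
a = "tensor_group"
b = "".join(["tensor_", "group"])  # built at runtime, typically a distinct object

print(a == b)  # True: compares values, which is what the format check needs
print(a is b)  # usually False: compares object identity, which depends on interning
```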

        return CompressionFormat.int4_quantized
    if weight_args.type == QuantizationType.INT.value:
        return CompressionFormat.int_quantized
