
[MoE] Cleanup MoE examples #1576

Open: wants to merge 24 commits into base `main`

32 changes: 16 additions & 16 deletions examples/quantizing_moe/README.md
@@ -17,17 +17,17 @@ pip install -e .
The provided example script demonstrates an end-to-end process for applying the quantization algorithm:

```bash
-python3 mixtral_moe_w8a8_fp8.py
+python3 mixtral_example.py
```

## Creating a Quantized MoE Model

-This example leverages `llm-compressor` and `compressed-tensors` to create an FP8-quantized `Mixtral-8x7B-Instruct-v0.1` model. The model is calibrated and trained using the `open_platypus` dataset.
+This example leverages `llm-compressor` and `compressed-tensors` to create an FP8-quantized `Mixtral-8x7B-Instruct-v0.1` model. The model is calibrated and trained using the `ultrachat_200k` dataset.

You can follow the detailed steps below or simply run the example script with:

```bash
-python mixtral_moe_w8a8_fp8.py
+python mixtral_example.py
```
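
For orientation, a condensed version of what `mixtral_example.py` does might look like the sketch below (a minimal sketch, assuming the preset `FP8` scheme and the `ultrachat_200k` calibration set mentioned above; the exact recipe, ignore list, and `oneshot` arguments in the shipped script may differ):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot

MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# FP8 quantization of all Linear layers; the MoE gate layers are kept at full
# precision because they are sensitive to quantization (the ignore pattern here
# is illustrative, not copied from the shipped script).
recipe = QuantizationModifier(
    targets="Linear",
    scheme="FP8",
    ignore=["lm_head", "re:.*block_sparse_moe.gate"],
)

# "ultrachat_200k" is assumed to be a registered dataset alias; a pre-tokenized
# datasets.Dataset can be passed instead, as the DeepSeek examples below do.
oneshot(
    model=model,
    dataset="ultrachat_200k",
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
)

# Save in compressed-tensors format.
SAVE_DIR = MODEL_ID.split("/")[-1] + "-FP8"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```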

### Step 1: Select a Model, Dataset, and Recipe
@@ -74,7 +74,7 @@ NOTE: Only per-tensor quantization is supported in vLLM as of now (`vllm==0.6.1`

The repository supports multiple quantization techniques configured via a recipe. Supported strategies include `tensor`, `group`, and `channel` quantization.

-In the above example, FP8 per-tensor quantization is used as specified by the `FP8` scheme. For other preset schemes, refer to the [quantization schemes](https://github.com/neuralmagic/compressed-tensors/blob/main/src/compressed_tensors/quantization/quant_scheme.py) in the `compressed-tensors` library.
+In the above example, quantization is specified by the `W4A16` scheme. For other preset schemes, refer to the [quantization schemes](https://github.com/neuralmagic/compressed-tensors/blob/main/src/compressed_tensors/quantization/quant_scheme.py) in the `compressed-tensors` library.
Member commented: Isn't this using the FP8 scheme above?

A custom scheme can also be specified using `config_groups`:

@@ -84,18 +84,18 @@ A custom scheme can also be specified using `config_groups`:
from llmcompressor.modifiers.quantization.gptq import GPTQModifier

config_groups = {
"group_0": {
"targets": ["Linear"],
"input_activations": None,
"output_activations": None,
"weights": {
"num_bits": 8,
"type": "int",
"symmetric": true,
"strategy": "group",
"group_size": 128,
}
}
"group_0": {
"targets": ["Linear"],
"input_activations": None,
"output_activations": None,
"weights": {
"num_bits": 8,
"type": "int",
"symmetric": true,
"strategy": "group",
"group_size": 128,
}
}
}

recipe = GPTQModifier(config_groups=config_groups)
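
Such a custom recipe is applied the same way as a preset scheme, by passing it to `oneshot`. The sketch below is a minimal illustration; `model` and `recipe` are assumed to be defined as in the snippet above, and the dataset alias and calibration settings are illustrative:

```python
from llmcompressor.transformers import oneshot

# `model` and `recipe` are assumed to be defined as in the snippet above;
# the dataset alias and calibration settings below are illustrative.
oneshot(
    model=model,
    dataset="ultrachat_200k",
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
)
```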
125 changes: 0 additions & 125 deletions examples/quantizing_moe/deepseek_moe_w4a16.py

This file was deleted.

8 changes: 0 additions & 8 deletions examples/quantizing_moe/deepseek_recipe_w4a16.yaml

This file was deleted.

@@ -12,18 +12,17 @@
# previous version or upgrading to a version where this bug is fixed

# select a Mixture of Experts model for quantization
MODEL_ID = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
MODEL_ID = "deepseek-ai/DeepSeek-V2.5"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True
)
Member commented on lines 17 to 19: Don't we need `device_map="auto"` to make this huge model fit? I think it would be nice to still keep around a small MoE example.

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
# it's recommended to use more calibration samples for MoE models so each expert is hit
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
-NUM_CALIBRATION_SAMPLES = 2048
+NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048


@@ -57,16 +56,12 @@ def tokenize(sample):

ds = ds.map(tokenize, remove_columns=ds.column_names)

-# define a llmcompressor recipe for INT8 W8A8 quantization
+# Configure the quantization algorithm to run.
# since the MoE gate layers are sensitive to quantization, we add them to the ignore
# list so they remain at full precision
-recipe = [
-    GPTQModifier(
-        targets="Linear",
-        scheme="W8A8",
-        ignore=["lm_head", "re:.*mlp.gate$"],
-    ),
-]
+recipe = GPTQModifier(
+    targets="Linear", scheme="W4A16", ignore=["lm_head", "re:.*mlp.gate$"]
+)

oneshot(
    model=model,
@@ -82,12 +77,10 @@ def tokenize(sample):
if Version(__version__) < Version("4.48"):
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
SAMPLE_INPUT = ["I love quantization because"]
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
inputs = tokenizer(SAMPLE_INPUT, return_tensors="pt", padding=True).to(model.device)
output = model.generate(**inputs, max_length=50)
text_output = tokenizer.batch_decode(output)
print(text_output)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to("cuda") for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================")
else:
    print(
@@ -96,6 +89,6 @@
    )

# Save to disk in compressed-tensors format.
-SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W8A8"
+SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W4A16"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
88 changes: 88 additions & 0 deletions examples/quantizing_moe/deepseekv3_example.py
@@ -0,0 +1,88 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modeling import prepare_for_quantization
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
# For DeepSeekv3, we require a full precision model in order to properly calibrate
# `DeepSeek-V3-BF16` is a DeepSeek-V3 FP8 model which has been converted to BF16
model_id = "RedHatAI/DeepSeek-V3-BF16"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
Member commented: Don't we need to fit this across GPUs?

Collaborator (author) replied: No, #1263 landed.

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = prepare_for_quantization(model)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
# since the MoE gate layers are sensitive to quantization, we add them to the ignore
# list so they remain at full precision
recipe = GPTQModifier(
    targets="Linear", scheme="W4A16", ignore=["lm_head", "re:.*mlp.gate$"]
)

# Apply algorithms.
# due to the large size of DeepSeekV3, we specify sequential targets such that
# only one MLP is loaded into GPU memory at a time
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    sequential_targets=["DeepseekV3Attention", "DeepseekV3MLP"],
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to("cuda") for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
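
Once saved with `save_compressed=True`, the checkpoint is stored in compressed-tensors format and can be loaded for inference, for example with vLLM. The snippet below is a minimal sketch, assuming a vLLM build with compressed-tensors W4A16 support; the directory name matches the `SAVE_DIR` produced above, and the parallelism setting is illustrative:

```python
from vllm import LLM, SamplingParams

# Load the compressed-tensors checkpoint produced by the script above.
# A model this large must be sharded across GPUs; the value here is illustrative.
llm = LLM(model="DeepSeek-V3-BF16-W4A16-G128", tensor_parallel_size=8)

outputs = llm.generate(["Hello my name is"], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```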