@fynnsu commented Aug 29, 2025

SUMMARY:
We're currently missing an abs() call when computing activation means. This results in negative mean values, which then interfere with the scale calculation, causing the scales to be set to nan.
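A minimal illustration of the failure mode (the values and the `alpha` exponent here are hypothetical, not the actual llm-compressor code): during the scale search, AWQ raises the activation mean to a fractional power, and in PyTorch a negative base with a fractional exponent evaluates to nan.

```python
import torch

x_mean = torch.tensor(-0.5)      # a negative activation mean, possible without abs()
alpha = 0.5                      # fractional exponent used during the scale search

print(x_mean.pow(alpha))         # tensor(nan) -- propagates into the scales
print(x_mean.abs().pow(alpha))   # tensor(0.7071) -- well-defined once abs() is applied
```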

As seen here, this abs() call existed in the AutoAWQ repo but was lost at some stage.
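For context, a sketch of the shape of the fix, assuming calibration accumulates a running per-channel activation statistic (the function and variable names are illustrative, not the actual llm-compressor internals):

```python
import torch

def accumulate_channel_means(running_sum: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Accumulate per-channel activation magnitudes over calibration batches.

    The caller is assumed to divide by the batch count afterwards.
    """
    # abs() is the fix: it keeps the accumulated statistic non-negative, matching
    # the original AutoAWQ behavior. Without it, channels whose activations average
    # out negative later yield nan when the mean is raised to a fractional power.
    return running_sum + x.abs().mean(dim=0)
```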

TEST PLAN:
After the fix, we applied AWQ to Qwen3-Coder-30B-A3B-Instruct and evaluated on humaneval and humaneval+:

| Algorithm | humaneval@1 | humaneval@10 | humaneval+@1 | humaneval+@10 |
|-----------|-------------|--------------|--------------|---------------|
| base      | 0.930       | 0.939        | 0.887        | 0.898         |
| GPTQ      | 0.927       | 0.945        | 0.885        | 0.905         |
| RTN       | 0.924       | 0.939        | 0.870       | 0.885         |
| AWQ       | 0.937       | 0.945        | 0.893        | 0.902         |
**awq script**

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier

MODEL_ID = "Qwen/Qwen3-Coder-30B-A3B-Instruct"

SAVE_DIR = MODEL_ID.split("/")[-1] + "-awq-sym-main-abs2-no-duo-scaling"


# Configure the quantization algorithm to run.
recipe = [
    AWQModifier(
        duo_scaling=False,
        ignore=[
            "lm_head",
            "re:.*mlp.gate$",
            "re:.*mlp.shared_expert_gate$",
            "re:visual.*",
        ],
        scheme="W4A16",
        targets=["Linear"],
    ),
]

# Select calibration dataset.
DATASET_ID = "codeparrot/self-instruct-starcoder"
DATASET_SPLIT = "curated"

# Select number of samples. 256 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 2048


def get_calib_dataset(tokenizer):
    from datasets import load_dataset

    ds = load_dataset(
        DATASET_ID,
        split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES*10}]",
    )

    def preprocess(example):
        chat_messages = [
            {"role": "user", "content": example["instruction"].strip()},
            {"role": "assistant", "content": example["output"].strip()},
        ]
        tokenized_messages = tokenizer.apply_chat_template(
            chat_messages, tokenize=True
        )
        return {"input_ids": tokenized_messages}

    ds = (
        ds.shuffle(seed=42)
        .map(preprocess, remove_columns=ds.column_names)
        .select(range(NUM_CALIBRATION_SAMPLES))
    )

    return ds


if __name__ == "__main__":
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype="auto", trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    ###
    ### Apply algorithms.
    ###
    oneshot(
        model=model,
        dataset=get_calib_dataset(tokenizer),
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
        log_dir=None,
        trust_remote_code_model=True,
    )

    # Save to disk compressed.
    model.save_pretrained(SAVE_DIR)
    tokenizer.save_pretrained(SAVE_DIR)
```

**gptq script**

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

MODEL_ID = "Qwen/Qwen3-Coder-30B-A3B-Instruct"

SAVE_DIR = MODEL_ID.split("/")[-1] + "-gptq-sym"


# Configure the quantization algorithm to run.
recipe = [
    GPTQModifier(
        ignore=[
            "lm_head",
            "re:.*mlp.gate$",
            "re:.*mlp.shared_expert_gate$",
            "re:visual.*",
        ],
        scheme="W4A16",
        targets=["Linear"],
    ),
]

# Select calibration dataset.
DATASET_ID = "codeparrot/self-instruct-starcoder"
DATASET_SPLIT = "curated"

# Select number of samples. 256 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 2048


def get_calib_dataset(tokenizer):
    from datasets import load_dataset

    ds = load_dataset(
        DATASET_ID,
        split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES*10}]",
    )

    def preprocess(example):
        chat_messages = [
            {"role": "user", "content": example["instruction"].strip()},
            {"role": "assistant", "content": example["output"].strip()},
        ]
        tokenized_messages = tokenizer.apply_chat_template(
            chat_messages, tokenize=True
        )
        return {"input_ids": tokenized_messages}

    ds = (
        ds.shuffle(seed=42)
        .map(preprocess, remove_columns=ds.column_names)
        .select(range(NUM_CALIBRATION_SAMPLES))
    )

    return ds


if __name__ == "__main__":
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype="auto", trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    ###
    ### Apply algorithms.
    ###
    oneshot(
        model=model,
        dataset=get_calib_dataset(tokenizer),
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
        log_dir=None,
        trust_remote_code_model=True,
    )


    # Save to disk compressed.
    model.save_pretrained(SAVE_DIR)
    tokenizer.save_pretrained(SAVE_DIR)
```

**rtn script**

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

MODEL_ID = "Qwen/Qwen3-Coder-30B-A3B-Instruct"

SAVE_DIR = MODEL_ID.split("/")[-1] + "-rtn-sym"


# Configure the quantization algorithm to run.
recipe = [
    QuantizationModifier(
        ignore=[
            "lm_head",
            "re:.*mlp.gate$",
            "re:.*mlp.shared_expert_gate$",
            "re:visual.*",
        ],
        scheme="W4A16",
        targets=["Linear"],
    ),
]

# Select calibration dataset.
DATASET_ID = "codeparrot/self-instruct-starcoder"
DATASET_SPLIT = "curated"

# Select number of samples. 256 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 256
MAX_SEQUENCE_LENGTH = 2048


def get_calib_dataset(tokenizer):
    from datasets import load_dataset

    ds = load_dataset(
        DATASET_ID,
        split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES*10}]",
    )

    def preprocess(example):
        chat_messages = [
            {"role": "user", "content": example["instruction"].strip()},
            {"role": "assistant", "content": example["output"].strip()},
        ]
        tokenized_messages = tokenizer.apply_chat_template(
            chat_messages, tokenize=True
        )
        return {"input_ids": tokenized_messages}

    ds = (
        ds.shuffle(seed=42)
        .map(preprocess, remove_columns=ds.column_names)
        .select(range(NUM_CALIBRATION_SAMPLES))
    )

    return ds


if __name__ == "__main__":
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype="auto", trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    ###
    ### Apply algorithms.
    ###
    oneshot(
        model=model,
        dataset=get_calib_dataset(tokenizer),
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
        log_dir=None,
        trust_remote_code_model=True,
    )

    # Save to disk compressed.
    model.save_pretrained(SAVE_DIR)
    tokenizer.save_pretrained(SAVE_DIR)

```


@gemini-code-assist bot left a comment


**Summary of Changes**

Hello @fynnsu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical bug in the AWQ quantization implementation that resulted in incorrect scale calculations due to negative activation values. By reintroducing a missing absolute value operation, the change ensures numerical stability and improves the overall accuracy and robustness of the quantization process, as validated by comprehensive evaluation tests.

**Highlights**

- Corrected AWQ Activation Mean Calculation: Implemented abs() calls to ensure activation means are always positive, resolving an issue where negative means led to NaN values in AWQ scale calculations.
- Improved Quantization Stability: This fix enhances the stability and reliability of the AWQ quantization process, as demonstrated by improved evaluation metrics for the Qwen3-Coder-30B-A3B-Instruct model.

@gemini-code-assist bot left a comment


**Code Review**

This pull request correctly identifies and fixes a bug where negative activation values could lead to nan scales in the AWQ algorithm by adding an abs() call during the accumulation of activation means. This aligns with the original AutoAWQ implementation, and the test results confirm the fix is effective.

I've added one comment regarding a potentially redundant abs() call that could be removed to improve code clarity.


@brian-dellabetta left a comment


this looks good after removing the redundant line

@fynnsu force-pushed the awq_abs_fix branch 2 times, most recently from dcaf3c0 to 09b2eeb, on September 2, 2025 14:31

@brian-dellabetta left a comment


thanks!


@kylesayrs left a comment


Great catch, nice job
