@yiliu30 yiliu30 commented Nov 5, 2025

Resolves #1968 (RFC: Add Intel AutoRound Quantization Algorithm Support).

Highlights

  • Introduced AutoRoundModifier to enable AutoRound quantization for wNa16 (a usage sketch follows the results below).
  • Added an end-to-end example and unit tests.
  • Verified functionality with local accuracy tests (GSM8K with a limit of 1000; results may fluctuate slightly due to non-determinism):

- LLMC-AutoRound:
vllm (pretrained=/storage/yiliu7/Meta-Llama-3-8B-Instruct-W4A16-G128-disbale-shuffule,tensor_parallel_size=1,max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,add_bos_token=True,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False), gen_kwargs: (None), limit: 1000.0, num_fewshot: None, batch_size: 128
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.737|±  |0.0139|
|     |       |strict-match    |     5|exact_match|↑  |0.736|±  |0.0139|

- AutoRound result as reference:
vllm (pretrained=/storage/yiliu7/meta-llama/Meta-Llama-3-8B-Instruct-ar/Meta-Llama-3-8B-Instruct-w4g128/,tensor_parallel_size=1,max_model_len=8192,max_num_batched_tokens=32768,max_num_seqs=128,add_bos_token=True,gpu_memory_utilization=0.8,dtype=bfloat16,max_gen_toks=2048,enable_prefix_caching=False), gen_kwargs: (None), limit: 1000.0, num_fewshot: None, batch_size: 128
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.739|±  |0.0139|
|     |       |strict-match    |     5|exact_match|↑  |0.740|±  |0.0139|

The eval command is attached for reference.
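
For orientation, a minimal usage sketch is shown below. The import path and constructor arguments of `AutoRoundModifier` are assumptions modeled on existing modifiers such as `GPTQModifier`; refer to the end-to-end example added in this PR for the actual recipe.

```python
# Hypothetical sketch: the import path and arguments of AutoRoundModifier are
# assumed (modeled on modifiers such as GPTQModifier), not taken from this PR.
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import AutoRoundModifier  # path assumed

recipe = AutoRoundModifier(
    targets="Linear",      # quantize Linear layers ...
    ignore=["lm_head"],    # ... except the output head
    scheme="W4A16",        # weight-only int4 (wNa16 with N=4); the runs above use group size 128
)

oneshot(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    dataset="open_platypus",
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=128,
)
```

The quantized model can then be evaluated with lm_eval's vLLM backend, as in the GSM8K runs above.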

Next stage (in later PRs)

  • Extend support for additional data types.
  • Add a mapping of group-wise quantization recipes between LLMC and AutoRound.
  • Add end-to-end tests.

cc @hshen14 @thuang6 @wenhuach21


@kylesayrs kylesayrs left a comment

Looks great. There are a few more things I'd like to point out; I'll give a full review soon.

```python
if cur_layer_idx >= len(state.model.model.layers):
    # skip the lm_head layer
    return
decoding_layer = state.model.model.layers[cur_layer_idx]
```

@kylesayrs kylesayrs Nov 5, 2025

The sequential pipeline is not guaranteed to break a model into decoder layers. If a user specifies sequential_targets="Linear", then SEQUENTIAL_EPOCH_END will fire once per Linear layer of the model.

One way to generalize this would be to have the sequential pipeline return the modules that were in the sequential subgraph, for example:

```python
from typing import Set

import torch


class Subgraph:
    def get_modules(self, model: torch.nn.Module, recurse: bool = False) -> Set[torch.nn.Module]:
        # collect the modules that were called by this subgraph
        nodes = self.graph.find_nodes(op="call_module")
        modules = set(model.get_submodule(node.target) for node in nodes)
        if recurse:
            # also include all nested submodules
            modules = set(m for module in modules for m in module.modules())

        return modules


class SequentialPipeline:
    ...
    # inside the calibration loop, pass each subgraph to the lifecycle callback
    for subgraph in subgraphs:
        LifecycleCallbacks.sequential_epoch_end(subgraph)


def apply_autoround(self, state, subgraph):
    # the modifier reassembles the modules covered by this subgraph
    decoding_layer = torch.nn.ModuleList(list(subgraph.get_modules(state.model)))
```

Collaborator

FYI, these changes are implemented in #1998, and this PR can potentially rebase on them.

Author

Thanks for the detailed explanation — this approach is definitely more robust. I’ve tested it locally, and it works well.

Will #1998 be merged soon? If so, I’d prefer to rebase on main and update my PR accordingly to avoid introducing too much code here.

Author

I have rebased onto main and updated the PR accordingly.

kylesayrs added a commit that referenced this pull request Nov 6, 2025
…ve `layer_sequential` pipeline (#1998)

## Purpose ##
* Enable better targeting of modules by modifiers such as
[AutoRound](#1994)
* Remove legacy pipeline (which is incompatible with this change)

## Changes ##
* Pass subgraph to `sequential_epoch_end`, allowing modifiers to view
all of the modules that were called in the subgraph
* Implement `submodules` method on `Subgraph` which returns all the
modules called by this subgraph
* Remove `LayerSequentialPipeline`, which does not use the `Subgraph`
API and has been superseded by the sequential pipeline

---------

Signed-off-by: Kyle Sayers <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Dipika Sikka <[email protected]>
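
For context, here is a rough sketch of how a modifier could consume the subgraph after #1998. Apart from the `sequential_epoch_end` hook and the `Subgraph.submodules` method named in the commit message above, every identifier and signature below is an assumption rather than the merged code.

```python
# Rough sketch only: hook name, method signatures, and class name are assumptions,
# not the implementation merged in #1998 or in this PR.
import torch


class ExampleModifier:
    def on_sequential_epoch_end(self, state, subgraph):  # hook name/signature assumed
        # The callback now receives the subgraph, so the modifier can recover
        # exactly the modules that ran in this sequential step instead of
        # indexing into state.model.model.layers.
        called = subgraph.submodules(state.model)  # call signature assumed
        linears = [m for m in called if isinstance(m, torch.nn.Linear)]
        # ... run AutoRound tuning over `linears` here ...
```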

@brian-dellabetta brian-dellabetta left a comment

Thanks for the contribution! This is really cool. I have a few comments/questions from an initial review.

yiliu30 and others added 8 commits November 6, 2025 19:36

@brian-dellabetta brian-dellabetta left a comment

Thanks for addressing my comments! A few more small things:

```python
    if BUILD_TYPE == "release"
    else "compressed-tensors>=0.12.3a2"
),
# TODO: replace it with the release version
```
Collaborator

Hi @yiliu30, do you have an estimate for when the next version of autoround will be released? Does it have the appropriate licensing to avoid issues like vllm-project/compressed-tensors#468?

yiliu30 and others added 2 commits November 7, 2025 19:55
yiliu30 commented Nov 8, 2025

> Hi @yiliu30, do you have an estimate for when the next version of autoround will be released? Does it have the appropriate licensing to avoid issues like vllm-project/compressed-tensors#468?

Hi @brian-dellabetta, we're planning to release the next version within the next 1–2 weeks; hope that works for you!
As for licensing, AutoRound is released under the Apache License 2.0, so there shouldn't be any concerns.
