[torchao] Support quantization configs using module swap #21982
Conversation
Force-pushed from d9ce8c2 to 4de8f17.
Code Review
This pull request enables support for torchao quantization methods that require module swaps, like AWQ, by wrapping the linear layer in an nn.Sequential. The change is well-contained and accompanied by a relevant test case. I've suggested one improvement to make the code more robust against future changes in torchao or different quantization configurations.
  dummy_linear[0].weight = param
  quantize_(dummy_linear, torchao_config)
- return dummy_linear.weight
+ return dummy_linear[0].weight
After quantization, the module may have been swapped. Instead of directly accessing dummy_linear[0].weight, retrieve the weight by inspecting the module's parameters. This avoids making fragile assumptions about the internal structure of the quantized module, which may change in future torchao versions.
dummy_linear[0].weight = param
quantize_(dummy_linear, torchao_config)
# After quantization, the module may have been swapped.
# We retrieve the single parameter, which is the quantized weight.
params = list(dummy_linear.parameters())
assert len(params) == 1, (
"Expected the dummy module to have exactly one parameter after "
f"quantization, but found {len(params)}."
)
return params[0].data
👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge. 🚀
Force-pushed from 4de8f17 to f1473ee.
Looks reasonable to me, thanks!
Looks good.
Can you merge from main to fix the CI failures?
Head branch was pushed to by a user without write access
Force-pushed from f1473ee to a3112b0.
… swap

Summary: The current torchao integration quantizes weights by wrapping them in a top-level linear module and calling quantize_ on it. This works for quantization methods that make in-place changes to the weight itself, such as int4 and float8, but quantization configs that need a module swap, such as AWQ, are not supported. To support these, we wrap the linear layer in nn.Sequential so it is no longer a top-level module and can be swapped for another module.

Test Plan: Uploaded an AWQ checkpoint, https://huggingface.co/torchao-testing/Phi-4-mini-instruct-int4wo-awq-0.13-dev, and test by loading the checkpoint:

```
python tests/quantization/test_torchao.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

Signed-off-by: Jerry Zhang <[email protected]>
Force-pushed from a3112b0 to f511cfd.
The current error is because we just landed AWQ updates in torchao recently and it's not picked up by nightly yet.
Summary:
The current torchao integration quantizes weights by wrapping them in a top-level linear module and calling quantize_ on it. This works for quantization methods that make in-place changes to the weight itself, such as int4 and float8, since these only modify the linear module in place. However, quantization configs that require a module swap, such as AWQ, are not supported. To support these, we wrap the linear layer in nn.Sequential so it is no longer a top-level module and can be swapped for another module.
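For illustration, here is a minimal sketch of the wrapping trick described above (not the exact vLLM implementation); torchao_config stands in for whatever torchao quantization config is in use, and the helper name is hypothetical:

```python
import torch
from torch import nn
from torchao.quantization import quantize_


def quantize_param_data(param: torch.nn.Parameter, torchao_config):
    """Hypothetical helper: quantize a single linear weight via torchao."""
    # Wrap the linear layer in nn.Sequential so it is not the top-level
    # module. quantize_ cannot swap out the module it is called on, but it
    # can swap child modules, which is what AWQ-style configs require.
    dummy_linear = nn.Sequential(
        nn.Linear(param.shape[1], param.shape[0], bias=False))
    dummy_linear[0].weight = param
    quantize_(dummy_linear, torchao_config)
    # The inner module may have been swapped, so read back the remaining
    # parameter instead of assuming dummy_linear[0].weight still exists.
    params = list(dummy_linear.parameters())
    assert len(params) == 1, "expected exactly one parameter after quantize_"
    return params[0].data
```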
Test Plan:
Uploaded an AWQ checkpoint, https://huggingface.co/torchao-testing/Phi-4-mini-instruct-int4wo-awq-0.13-dev, and we test by loading the checkpoint.
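As a rough illustration of that test flow (not the PR's actual test file; the prompt and sampling settings are arbitrary), loading the checkpoint through vLLM's offline API could look like this:

```python
from vllm import LLM, SamplingParams

# Load the AWQ-quantized torchao checkpoint; the quantization settings are
# assumed to be picked up from the checkpoint's own config.
llm = LLM(model="torchao-testing/Phi-4-mini-instruct-int4wo-awq-0.13-dev")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```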
Reviewers:
Subscribers:
Tasks:
Tags: