From f511cfd8295b6f60a5ae416a1d7ce82a4fc95ace Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 30 Jul 2025 18:50:49 -0700
Subject: [PATCH] Allow torchao quant to support quantization configs relying
 on module swap

Summary:
The current torchao integration quantizes a weight by wrapping it in a
top-level linear module and calling quantize_ on that module. This works for
quantization methods that change the weight in place, such as int4 and
float8, but not for quantization configs that require a module swap, such as
AWQ. To support those, we wrap the linear in nn.Sequential so that it is no
longer a top-level module and can be swapped to another module.

Test Plan:
Uploaded an AWQ checkpoint:
https://huggingface.co/torchao-testing/Phi-4-mini-instruct-int4wo-awq-0.13-dev
and test by loading the checkpoint:

```
python tests/quantization/test_torchao.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

Signed-off-by: Jerry Zhang
---
 tests/quantization/test_torchao.py | 15 +++++++++++++++
 .../layers/quantization/torchao.py | 16 +++++++++-------
 2 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/tests/quantization/test_torchao.py b/tests/quantization/test_torchao.py
index eef3568efea1..c84608fe6f2f 100644
--- a/tests/quantization/test_torchao.py
+++ b/tests/quantization/test_torchao.py
@@ -75,5 +75,20 @@ def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
     print(output)
 
 
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+def test_phi4mini_int4wo_awq_model_loading_with_params(vllm_runner):
+    torch._dynamo.reset()
+    model_name = "torchao-testing/Phi-4-mini-instruct-int4wo-awq-0.13-dev"
+    with vllm_runner(model_name=model_name,
+                     quantization="torchao",
+                     dtype="bfloat16",
+                     pt_load_map_location="cuda:0") as llm:
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+
+    assert output
+    print(output)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/vllm/model_executor/layers/quantization/torchao.py b/vllm/model_executor/layers/quantization/torchao.py
index 63b2ab6bab06..3498d2994c2a 100644
--- a/vllm/model_executor/layers/quantization/torchao.py
+++ b/vllm/model_executor/layers/quantization/torchao.py
@@ -152,18 +152,20 @@ def torchao_quantize_param_data(param: torch.Tensor,
     from torchao.quantization import quantize_
 
     assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
-    """ 
-    Avoid real weight allocation for faster load, since we will 
+    """
+    Avoid real weight allocation for faster load, since we will
     end up setting it to param.
     """
     with torch.device("meta"):
-        dummy_linear = torch.nn.Linear(param.shape[1],
-                                       param.shape[0],
-                                       bias=False)
+        # The linear can't be a top-level module since quantize_ is in-place
+        # while some of our configs need to do a module swap, and only
+        # non-top-level modules support module swap.
+        dummy_linear = torch.nn.Sequential(
+            torch.nn.Linear(param.shape[1], param.shape[0], bias=False))
 
-    dummy_linear.weight = param
+    dummy_linear[0].weight = param
     quantize_(dummy_linear, torchao_config)
-    return dummy_linear.weight
+    return dummy_linear[0].weight
 
 
 class TorchAOLinearMethod(LinearMethodBase):
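
For context, here is a minimal standalone sketch of the wrapping pattern the
patch relies on. It is not part of the diff: the helper name
`quantize_weight`, the use of `Int4WeightOnlyConfig` as a stand-in config, and
the tensor shapes are illustrative assumptions. It only shows why nesting the
`Linear` inside `nn.Sequential` lets `quantize_` either rewrite the weight in
place or swap the child module, assuming a recent torchao release that exposes
config objects.

```python
import torch
from torchao.quantization import quantize_, Int4WeightOnlyConfig


def quantize_weight(weight: torch.Tensor, config) -> torch.nn.Parameter:
    # Allocate the dummy linear on the meta device so no real memory is used;
    # the real weight is attached afterwards, mirroring the patch.
    with torch.device("meta"):
        # Nest the linear so it is not the top-level module passed to
        # quantize_; module-swap based configs (e.g. AWQ) can then replace
        # holder[0] with a different module type.
        holder = torch.nn.Sequential(
            torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False))

    holder[0].weight = torch.nn.Parameter(weight, requires_grad=False)
    quantize_(holder, config)  # may mutate holder[0] or swap it entirely
    return holder[0].weight


if __name__ == "__main__":
    # Illustrative shape and config only; int4 weight-only expects bf16 on GPU.
    w = torch.randn(1024, 2048, dtype=torch.bfloat16, device="cuda")
    qw = quantize_weight(w, Int4WeightOnlyConfig(group_size=128))
    print(type(qw), qw.shape)
```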