
Commit f511cfd

Allow torchao quant to support quantization configs relying on module swap
Summary: The current torchao integration quantizes a weight by wrapping it in a top-level linear module and calling quantize_ on that module. This works for quantization methods that modify the weight in place, such as int4 and float8, but not for quantization configs that require a module swap, such as AWQ. To support those, we wrap the linear in nn.Sequential so it is no longer a top-level module and can be swapped for another module.

Test Plan: Uploaded an AWQ checkpoint, https://huggingface.co/torchao-testing/Phi-4-mini-instruct-int4wo-awq-0.13-dev, and test by loading the checkpoint:

```
python tests/quantization/test_torchao.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

Signed-off-by: Jerry Zhang <[email protected]>
1 parent 9659bc7 commit f511cfd
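For context, here is a minimal sketch (not part of this commit, and not vLLM code) of the idea described in the commit message. It assumes a recent torchao release; Int8WeightOnlyConfig is used only as a stand-in example of an AOBaseConfig. An in-place config works either way, while an AWQ-style config additionally needs to replace the Linear with a different module type, which quantize_ can only do when the Linear has a parent module whose attribute can be reassigned.

```python
# Minimal sketch (not vLLM code) of why the Linear is wrapped in nn.Sequential.
# Int8WeightOnlyConfig is only an example AOBaseConfig; exact config class
# names depend on the torchao version.
import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_

weight = torch.nn.Parameter(torch.randn(128, 256))

# Previous approach: the Linear itself is the top-level module given to
# quantize_, so a config can only mutate it in place. Configs that need to
# replace the Linear with a different module type have no parent to swap on.
top_level = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
top_level.weight = weight

# New approach: the Linear is a child of nn.Sequential, so quantize_ can
# either mutate the weight in place or swap the child module entirely.
wrapped = torch.nn.Sequential(
    torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False))
wrapped[0].weight = weight
quantize_(wrapped, Int8WeightOnlyConfig())

# For the configs this commit targets, the (possibly swapped) child still
# exposes a .weight, so indexing into the Sequential recovers the quantized
# tensor -- the same access pattern the updated torchao.py uses below.
print(type(wrapped[0].weight))
```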

File tree

2 files changed: +24 −7 lines


tests/quantization/test_torchao.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -75,5 +75,20 @@ def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
         print(output)
 
 
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+def test_phi4mini_int4wo_awq_model_loading_with_params(vllm_runner):
+    torch._dynamo.reset()
+    model_name = "torchao-testing/Qwen3-4B-int4wo-awq-0.13-dev"
+    with vllm_runner(model_name=model_name,
+                     quantization="torchao",
+                     dtype="bfloat16",
+                     pt_load_map_location="cuda:0") as llm:
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+
+        assert output
+        print(output)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
```

vllm/model_executor/layers/quantization/torchao.py

Lines changed: 9 additions & 7 deletions
```diff
@@ -152,18 +152,20 @@ def torchao_quantize_param_data(param: torch.Tensor,
     from torchao.quantization import quantize_
 
     assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
-    """
-    Avoid real weight allocation for faster load, since we will
+    """
+    Avoid real weight allocation for faster load, since we will
     end up setting it to param.
     """
     with torch.device("meta"):
-        dummy_linear = torch.nn.Linear(param.shape[1],
-                                       param.shape[0],
-                                       bias=False)
+        # linear can't be top level module since quantize_ is inplace
+        # while some of our configs need to do module swap, and only non-top
+        # level modules support module swap
+        dummy_linear = torch.nn.Sequential(
+            torch.nn.Linear(param.shape[1], param.shape[0], bias=False))
 
-        dummy_linear.weight = param
+        dummy_linear[0].weight = param
         quantize_(dummy_linear, torchao_config)
-        return dummy_linear.weight
+        return dummy_linear[0].weight
 
 
 class TorchAOLinearMethod(LinearMethodBase):
```
