Commit d9ce8c2

allow torchao quant to support quantization configs relying on module swap
Summary:
The current torchao integration quantizes weights by wrapping them in a top-level linear module and calling quantize_ on it. This works for quantization methods that change the weight in place, such as int4 and float8, but quantization configs that require a module swap, such as AWQ, are not supported. To support these, we wrap the linear in nn.Sequential so it is no longer a top-level module and can be swapped for another module.

Test Plan:
Uploaded an AWQ checkpoint, https://huggingface.co/torchao-testing/Phi-4-mini-instruct-int4wo-awq-0.13-dev, and test by loading the checkpoint:

```
python tests/quantization/test_torchao.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent ca9e2be commit d9ce8c2
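
To make the mechanism described in the commit message concrete, here is a minimal sketch of the wrapping pattern (not part of the diff). It assumes torchao is installed with quantize_ and Int4WeightOnlyConfig available and a CUDA device present; Int4WeightOnlyConfig is only a stand-in for any torchao config. The point is that the Linear becomes a child of the Sequential, so a config that swaps modules has a parent module in which to perform the swap.

```python
# Sketch only: illustrates why the Linear is wrapped in nn.Sequential.
# Assumes torchao provides quantize_ and Int4WeightOnlyConfig; the config is a
# placeholder -- any AOBaseConfig could be passed here.
import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

weight = torch.nn.Parameter(torch.randn(64, 128, dtype=torch.bfloat16,
                                        device="cuda"),
                            requires_grad=False)

# Old pattern: the Linear itself is the top-level module handed to quantize_.
# Inplace configs (int4, float8) can rewrite its weight, but a module-swapping
# config (e.g. AWQ) has no parent module in which to replace the Linear.
top_level_linear = torch.nn.Linear(128, 64, bias=False)

# New pattern: wrap the Linear in nn.Sequential so it is a child module and can
# be swapped out by quantize_ when the config requires it.
wrapped = torch.nn.Sequential(torch.nn.Linear(128, 64, bias=False))
wrapped[0].weight = weight
quantize_(wrapped, Int4WeightOnlyConfig())  # typically needs a CUDA device
quantized_weight = wrapped[0].weight
```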

File tree

2 files changed: +28 -7 lines changed

tests/quantization/test_torchao.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -74,6 +74,20 @@ def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
     assert output
     print(output)
 
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+def test_phi4mini_int4wo_awq_model_loading_with_params(vllm_runner):
+    torch._dynamo.reset()
+    model_name = "torchao-testing/Phi-4-mini-instruct-int4wo-awq-0.13-dev"
+    with vllm_runner(model_name=model_name,
+                     quantization="torchao",
+                     dtype="bfloat16",
+                     pt_load_map_location="cuda:0") as llm:
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+
+    assert output
+    print(output)
+
 
 if __name__ == "__main__":
     pytest.main([__file__])
```
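
For orientation, the following is a rough standalone equivalent of what the test above exercises, written against vLLM's offline LLM API rather than the vllm_runner fixture. It is a hedged sketch, not part of the commit: it assumes a CUDA device and network access to the checkpoint, and it omits the pt_load_map_location="cuda:0" argument that the fixture passes.

```python
# Sketch only: load the AWQ torchao checkpoint from the test with the offline
# LLM API. Greedy decoding approximates llm.generate_greedy(..., max_tokens=32).
from vllm import LLM, SamplingParams

llm = LLM(model="torchao-testing/Phi-4-mini-instruct-int4wo-awq-0.13-dev",
          quantization="torchao",
          dtype="bfloat16")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```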

vllm/model_executor/layers/quantization/torchao.py

Lines changed: 14 additions & 7 deletions
```diff
@@ -152,18 +152,25 @@ def torchao_quantize_param_data(param: torch.Tensor,
     from torchao.quantization import quantize_
 
     assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
-    """
-    Avoid real weight allocation for faster load, since we will
+    """
+    Avoid real weight allocation for faster load, since we will
     end up setting it to param.
     """
     with torch.device("meta"):
-        dummy_linear = torch.nn.Linear(param.shape[1],
-                                       param.shape[0],
-                                       bias=False)
+        # linear can't be top level module since quantize_ is inplace
+        # while some of our configs need to do module swap, and only non-top level
+        # modules support module swap
+        dummy_linear = torch.nn.Sequential(
+            torch.nn.Linear(
+                param.shape[1],
+                param.shape[0],
+                bias=False
+            )
+        )
 
-    dummy_linear.weight = param
+    dummy_linear[0].weight = param
     quantize_(dummy_linear, torchao_config)
-    return dummy_linear.weight
+    return dummy_linear[0].weight
 
 
 class TorchAOLinearMethod(LinearMethodBase):
```
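
As a usage illustration (not part of the commit), the patched helper takes a weight tensor plus a torchao config and returns the quantized weight taken back off the wrapped dummy linear. The import path and the second-argument position follow the hunk above, and Int4WeightOnlyConfig is again only an example config; the call assumes a CUDA device.

```python
# Hypothetical call into the patched helper shown in the diff above; argument
# order and import path follow the hunk, the config is an example only.
import torch
from torchao.quantization import Int4WeightOnlyConfig
from vllm.model_executor.layers.quantization.torchao import (
    torchao_quantize_param_data)

weight = torch.nn.Parameter(
    torch.randn(256, 512, dtype=torch.bfloat16, device="cuda"),
    requires_grad=False)
quantized = torchao_quantize_param_data(weight, Int4WeightOnlyConfig())
# The returned weight is backed by a torchao quantized tensor subclass.
print(type(quantized), type(quantized.data))
```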
