Make SmoothQuant more General #2728
Conversation
Summary:
- Added SmoothQuantConfig as a base config and made corresponding changes in other parts of the flow

Test Plan:
- Qwen3-8B with example.py and unittest
- Additional test plans required

ETC:
- Fix typo in README.md for SmoothQuant
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2728
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit bd6bf13 with merge base 2eae09b. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@jerryzh168 Could you please look into this PR? It was inspired by #2659 (comment), which asks for a more generalized SmoothQuant API.
Thanks @namgyu-youn, this is a step towards that but not fully general yet; it seems to be a quick change to add it though, commented inline. Also, it seems SmoothQuant is not very popular at the moment: https://huggingface.co/models?search=smoothquant, so I'd like to wait a bit before we invest more effort into it. Let me know if you are interested in contributing more to torchao; we have many more higher-priority issues that you can help with, I think.
Thanks for the kind info; I truly love your team's work after reviewing TorchAO: CodeML @ ICML 2025. The recently updated contribution guide could be a great choice for the next contribution, but personally I prefer the sparsity (pruning) module. Unfortunately, I heard the main POC (@jcaip) is on vacation, which makes it hard for me to make progress. The following are my recent activities related to the sparsity module:
If there is no major progress on the sparsity module, quantization (new APIs or primitive ops) might be a next step. Let me know if there is a good second issue for it. P.S. Could you please check #2644? It hasn't been merged yet after being approved (no CI broken). Also, #2660 has been waiting for review (I am fine with closing it because it is low priority).
Test result (
@jerryzh168 Hi, I am happy to show you a more generalized SmoothQuant API by using the Quantization API (
insert_smooth_quant_observer_(model, alpha, quant_mode)
# Step 1: Insert observers to find average magnitude and calculate scales
config = SmoothQuantConfig(
    base_config=int8_dynamic_activation_int8_weight(),
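For context, here is a minimal sketch of how the config-driven flow in this diff could look end to end. The PREPARE and CONVERT step names are assumed by analogy with the PREPARE_FOR_LOADING step used later in this thread, and the toy model and calibration loop are made up for illustration.

import torch

from torchao.prototype.smoothquant import SmoothQuantConfig
from torchao.prototype.smoothquant.core import SmoothQuantStep
from torchao.quantization import quantize_
from torchao.quantization.quant_api import Int8DynamicActivationInt8WeightConfig

# Hypothetical toy model standing in for a real transformer
model = torch.nn.Sequential(torch.nn.Linear(512, 512))

# Step 1: insert observers that record activation magnitudes
config = SmoothQuantConfig(
    base_config=Int8DynamicActivationInt8WeightConfig(),
    step=SmoothQuantStep.PREPARE,  # step name assumed
    alpha=0.5,
)
quantize_(model, config)

# Step 2: run calibration data through the observed model
for _ in range(8):
    model(torch.randn(2, 512))

# Step 3: fold the smoothing factors into the weights and apply the base int8 config
config.step = SmoothQuantStep.CONVERT  # step name assumed
quantize_(model, config)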
can generalize the example API to take quant type configs now, see
ao/torchao/prototype/awq/example.py, line 307 in 751d7f6:
help="Quantization method. Options are either awq-int4wo-<group_size>, or int4wo-<group_size>.",
Thanks, but how about using Int8DynamicActivationInt8WeightConfig as the default here and dividing the PR? It might require checking which APIs are compatible with SmoothQuantConfig, and building a unittest.
Btw, we can unify the commonly used utils functions in AWQ and SmoothQuant: get_calib_dataset, wiki2_eval, and quantize_and_eval.
yeah sure
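To make the unification idea concrete, here is a rough sketch of a shared calibration-data helper that both the AWQ and SmoothQuant examples could import. The module location and exact signature are hypothetical; only the get_calib_dataset name comes from the existing examples.

# hypothetical shared module, e.g. torchao/prototype/_example_utils.py
from datasets import load_dataset


def get_calib_dataset(tokenizer, n_samples=100, block_size=512):
    # Tokenize wikitext-2 once and slice it into fixed-size calibration blocks
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    text = "\n\n".join(dataset["text"])
    tokens = tokenizer(text, return_tensors="pt").input_ids[0]

    blocks = []
    for i in range(n_samples):
        chunk = tokens[i * block_size : (i + 1) * block_size]
        if chunk.numel() < block_size:
            break
        blocks.append(chunk.unsqueeze(0))
    return blocks

Both example.py scripts could then drop their local copies and import this helper instead.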
| print(f"time for convert: {time.time() - t0:.02f} seconds") | ||
|
|
||
| # Set up config for loading | ||
| quant_config.step = SmoothQuantStep.PREPARE_FOR_LOADING |
does this work? you can check if it works by the following:
export MODEL=YOUR_SAVED_SMOOTHQUANT_MODEL
lm_eval --model hf --model_args pretrained=$MODEL --tasks $TASK --device cuda:0 --batch_size auto --limit 50
# vllm
export MODEL=YOUR_SAVED_SMOOTHQUANT_MODEL
python benchmarks/benchmark_latency.py --input-len 256 --output-len 256 --model $MODEL
I hoped so because it works similarly to AWQ, but I just tested it with the following code for assurance and got this log:
import tempfile
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torchao.prototype.smoothquant import SmoothQuantConfig
from torchao.prototype.smoothquant.core import SmoothQuantStep
from torchao.prototype.smoothquant.example import quantize_and_eval
from torchao.quantization import quantize_
from torchao.quantization.quant_api import Int8DynamicActivationInt8WeightConfig
MODEL_NAME = "microsoft/DialoGPT-small"
# Step 1: Create quantized model
with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as f:
model_path = f.name
quantize_and_eval(MODEL_NAME, 0.5, ['PPL'], 256, 5, 'cuda', torch.float32, False, model_path, None)
# Step 2: Test PREPARE_FOR_LOADING
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float32).cuda()
quantize_(model, SmoothQuantConfig(
base_config=Int8DynamicActivationInt8WeightConfig(),
step=SmoothQuantStep.PREPARE_FOR_LOADING,
alpha=0.5,
))
# Test inference
test_input = tokenizer('Hello world', return_tensors='pt').to('cuda')
with torch.no_grad():
output = model(**test_input)
generated = model.generate(**test_input, max_length=20, do_sample=False)
print(f"✓ Inference: {output.logits.shape}")
print(f"✓ Generation: {tokenizer.decode(generated[0], skip_special_tokens=True)}")Loading model on cuda...
Time to load model: 1.86 seconds
running SmoothQuant prepare and calibrate
Repo card metadata block was not found. Setting CardData to empty.
Token indices sequence length is longer than the specified maximum sequence length for this model (1443 > 1024). Running this sequence through the model will result in indexing errors
time for prepare and calibration: 5.20 seconds
running SmoothQuant convert
time for convert: 0.04 seconds
Saving model to /tmp/tmpqeme5s1r.pt
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
✓ Inference: torch.Size([1, 4, 50257])
✓ Generation: TorchAO TorchAO

For sure, we should benchmark them with your suggestion, but I want to carefully suggest splitting that into a separate PR.
OK sounds good to divide the PR
please add a test for sanity-checking the accuracy / functionality of the smoothquant implementation, see comments inline
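As a rough illustration of such a sanity check (not the final test), the sketch below compares the SmoothQuant-quantized output against the float reference on a toy model; the PREPARE/CONVERT step names and the SQNR threshold are assumptions.

import unittest

import torch

from torchao.prototype.smoothquant import SmoothQuantConfig
from torchao.prototype.smoothquant.core import SmoothQuantStep
from torchao.quantization import quantize_
from torchao.quantization.quant_api import Int8DynamicActivationInt8WeightConfig


class TestSmoothQuantSanity(unittest.TestCase):
    def test_accuracy_vs_float_reference(self):
        torch.manual_seed(0)
        model = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()
        x = torch.randn(16, 64)
        ref = model(x)

        # prepare: insert observers, then calibrate with one forward pass
        config = SmoothQuantConfig(
            base_config=Int8DynamicActivationInt8WeightConfig(),
            step=SmoothQuantStep.PREPARE,  # step name assumed
            alpha=0.5,
        )
        quantize_(model, config)
        model(x)

        # convert: fold smoothing factors and apply the base int8 config
        config.step = SmoothQuantStep.CONVERT  # step name assumed
        quantize_(model, config)
        out = model(x)

        # the quantized output should stay reasonably close to the reference
        sqnr = 20 * torch.log10(ref.norm() / (ref - out).norm())
        self.assertGreater(sqnr.item(), 20.0)


if __name__ == "__main__":
    unittest.main()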
Thanks! looks good
the error seems to be real: https://github.com/pytorch/ao/actions/runs/17455759907/job/49707357714?pr=2728. You can't import from test files; can you define the ToyLinearModel in test_smoothquant itself?
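For example, a small ToyLinearModel could live directly in test_smoothquant.py instead of being imported from another test file; this is only a sketch, and the layer sizes in the actual test may differ.

import torch


class ToyLinearModel(torch.nn.Module):
    # Minimal two-layer model defined locally so the test file has no cross-test imports
    def __init__(self, m=512, n=256, k=128):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)

    def forward(self, x):
        return self.linear2(self.linear1(x))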
please revert changes to test_integration.py
also can you run the tests locally first?
Unfortunately, I don't have access to L40S GPUs. Here is the result of running the tests locally on an A100 80GB PCIe MIG instance.

Result of integration test after revert

$ pytest test/integration --verbose -s
===================================================================================================== warnings summary =====================================================================================================
.venv/lib/python3.10/site-packages/triton/runtime/autotuner.py:97
.venv/lib/python3.10/site-packages/triton/runtime/autotuner.py:97
/home/elicer/ao/.venv/lib/python3.10/site-packages/triton/runtime/autotuner.py:97: DeprecationWarning: warmup, rep, and use_cuda_graph parameters are deprecated. See https://github.com/triton-lang/triton/pull/4496 for details.
warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "
torchao/utils.py:408
/home/elicer/ao/torchao/utils.py:408: UserWarning: TORCH_VERSION_AT_LEAST_2_8 is deprecated and will be removed in torchao 0.14.0
warnings.warn(self.msg)
test/integration/test_integration.py::TestSubclass::test_int8_weight_only_quant_subclass_3_cuda
/home/elicer/ao/.venv/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:282: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
warnings.warn(
test/integration/test_integration.py: 49 warnings
/home/elicer/ao/torchao/utils.py:408: UserWarning: TORCH_VERSION_AT_LEAST_2_7 is deprecated and will be removed in torchao 0.14.0
warnings.warn(self.msg)
test/integration/test_integration.py::SmoothquantIntegrationTest::test_on_dummy_distilbert
/home/elicer/ao/test/integration/test_integration.py:1429: DeprecationWarning: torch.ao.quantization is deprecated and will be removed in 2.10.
For migrations of users:
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e)
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e)
see https://github.com/pytorch/ao/issues/2259 for more details
model_copy2 = torch.ao.quantization.quantize_dynamic(
test/integration/test_integration.py::SmoothquantIntegrationTest::test_on_dummy_distilbert
/home/elicer/ao/.venv/lib/python3.10/site-packages/torch/ao/quantization/quantize.py:566: DeprecationWarning: torch.ao.quantization is deprecated and will be removed in 2.10.
For migrations of users:
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e)
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e)
see https://github.com/pytorch/ao/issues/2259 for more details
convert(model, mapping, inplace=True)
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================================ 152 passed, 207 skipped, 55 warnings in 124.51s (0:02:04) =================================================================================

Result of SmoothQuant test

$ pytest test/prototype/test_smoothquant.py --verbose -s
=================================================================================================== test session starts ====================================================================================================
platform linux -- Python 3.10.15, pytest-8.4.2, pluggy-1.6.0 -- /home/elicer/ao/.venv/bin/python3
cachedir: .pytest_cache
hypothesis profile 'default'
rootdir: /home/elicer/ao
plugins: hypothesis-6.138.14
collecting ... TMA benchmarks will be running without grid constant TMA descriptor.
collected 6 items
test/prototype/test_smoothquant.py::TestSmoothQuant::test_observer_insertion_base_config0 PASSED
test/prototype/test_smoothquant.py::TestSmoothQuant::test_prepare_for_loading_base_config0 PASSED
test/prototype/test_smoothquant.py::TestSmoothQuant::test_smoothquant_accuracy_alpha_0_5_base_config0_device_cpu_bfloat16 convert: module is not SmoothQuantObservedLinear, skipping: <class 'torch.nn.modules.linear.Linear'> PASSED
test/prototype/test_smoothquant.py::TestSmoothQuant::test_smoothquant_accuracy_alpha_0_5_base_config0_device_cuda_bfloat16 convert: module is not SmoothQuantObservedLinear, skipping: <class 'torch.nn.modules.linear.Linear'> PASSED
test/prototype/test_smoothquant.py::TestSmoothQuant::test_smoothquant_accuracy_alpha_0_75_base_config0_device_cpu_bfloat16 convert: module is not SmoothQuantObservedLinear, skipping: <class 'torch.nn.modules.linear.Linear'> PASSED
test/prototype/test_smoothquant.py::TestSmoothQuant::test_smoothquant_accuracy_alpha_0_75_base_config0_device_cuda_bfloat16 convert: module is not SmoothQuantObservedLinear, skipping: <class 'torch.nn.modules.linear.Linear'> PASSED
===================================================================================================== warnings summary =====================================================================================================
.venv/lib/python3.10/site-packages/triton/runtime/autotuner.py:97
.venv/lib/python3.10/site-packages/triton/runtime/autotuner.py:97
/home/elicer/ao/.venv/lib/python3.10/site-packages/triton/runtime/autotuner.py:97: DeprecationWarning: warmup, rep, and use_cuda_graph parameters are deprecated. See https://github.com/triton-lang/triton/pull/4496 for details.
warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================================================================== 6 passed, 2 warnings in 4.98s ===============================================================================================
please skip the failed test when there is no cuda, like this:
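As an illustrative stand-in for the pattern being requested, a guard along these lines skips the CUDA-only cases on machines without a GPU; the test name here is just a placeholder.

import unittest

import torch


class TestSmoothQuant(unittest.TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_smoothquant_accuracy_cuda(self):
        # CUDA-only body goes here; skipped automatically on CPU-only machines
        self.assertTrue(torch.cuda.is_available())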
Done, could you look into it?
This PR is titled "Make SmoothQuant more General"; however, it removes static quant support and prevents UT from running in CPU-only environments, which actually makes SmoothQuant more specific.
@Xia-Weiwen We decided to split the PR to support more quantization APIs, as discussed in #2728 (comment). In fact, what "general" refers to here is the new SmoothQuant API structure (see ao/test/prototype/test_smoothquant.py, lines 127 to 164 in 0d3217d).
Summary

Add SmoothQuantConfig as a base config and SmoothQuantObserver for computing the smoothing factor. Apply corresponding changes in other parts of the SmoothQuant API flow.

Benchmark

All experiments use the meta-llama/Llama-2-7b-chat-hf model with max sequence length (SeqLen) 512 and calibration limit 128 on a 1x H100 80GB HBM2 instance. For comprehensive benchmarking, we compare three cases: 1. original, 2. W8A8, 3. SmoothQuant (W8A8). The result shows that SmoothQuant with W8A8 slightly increases perplexity while reducing latency by 33.82%. Since the tinygemm kernel only takes bfloat16 inputs, tokens/sec decreases for float16 input.

Test Plan

This PR addresses the prototype benchmark. Experiments are recorded in the "Benchmark" section using example.py with Llama-2-7b-chat-hf for both quantization and model saving. The unittest is also updated for the change.

Future Plan

Build a benchmark within the vLLM ecosystem for AWQ and SmoothQuant. See #2815 for more info.