-
Notifications
You must be signed in to change notification settings - Fork 342
Add NVFP4 QAT #2666
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add NVFP4 QAT #2666
Changes from all commits
Commits
Show all changes
52 commits
Select commit
Hold shift + click to select a range
c358b1b
[bc-breaking] Generalize FakeQuantizeConfig beyond intx
andrewor14 d076264
New multi-step QAT API
andrewor14 8f56651
Update on "New multi-step QAT API"
andrewor14 7a9fe90
Update on "New multi-step QAT API"
andrewor14 1e88ebf
Deprecate old QAT APIs
andrewor14 12e8c3f
Update base for Update on "Deprecate old QAT APIs"
andrewor14 2ed5e50
Update on "Deprecate old QAT APIs"
andrewor14 019b665
Update base for Update on "Deprecate old QAT APIs"
andrewor14 b0c4721
Update on "Deprecate old QAT APIs"
andrewor14 0eb0983
Update base for Update on "Deprecate old QAT APIs"
andrewor14 c91b218
Update on "Deprecate old QAT APIs"
andrewor14 affc74e
Update base for Update on "Deprecate old QAT APIs"
andrewor14 3069075
Update on "Deprecate old QAT APIs"
andrewor14 b41f4e7
Update base for Update on "Deprecate old QAT APIs"
andrewor14 12d920b
Update on "Deprecate old QAT APIs"
andrewor14 68728d7
Update base for Update on "Deprecate old QAT APIs"
andrewor14 52f72a5
Update on "Deprecate old QAT APIs"
andrewor14 56415d6
Update base for Update on "Deprecate old QAT APIs"
andrewor14 08f87af
Update on "Deprecate old QAT APIs"
andrewor14 be45ff4
Update base for Update on "Deprecate old QAT APIs"
andrewor14 28b3b41
Update on "Deprecate old QAT APIs"
andrewor14 9baae23
Update base for Update on "Deprecate old QAT APIs"
andrewor14 62cd942
Update on "Deprecate old QAT APIs"
andrewor14 1c30bbb
Update base for Update on "Deprecate old QAT APIs"
andrewor14 2fbfbb6
Update on "Deprecate old QAT APIs"
andrewor14 87175e9
Add NVFP4 QAT
andrewor14 8f4953f
Update base for Update on "Add NVFP4 QAT"
andrewor14 a85be2e
Update on "Add NVFP4 QAT"
andrewor14 7473c23
Update base for Update on "Add NVFP4 QAT"
andrewor14 807a60b
Update on "Add NVFP4 QAT"
andrewor14 7ad639b
Update base for Update on "Add NVFP4 QAT"
andrewor14 32b6fa7
Update on "Add NVFP4 QAT"
andrewor14 92a26ac
Update base for Update on "Add NVFP4 QAT"
andrewor14 a1a4e51
Update on "Add NVFP4 QAT"
andrewor14 f9c7ccd
Update base for Update on "Add NVFP4 QAT"
andrewor14 15d43e0
Update on "Add NVFP4 QAT"
andrewor14 6fc9be3
Update base for Update on "Add NVFP4 QAT"
andrewor14 908abe2
Update on "Add NVFP4 QAT"
andrewor14 da9946c
Update base for Update on "Add NVFP4 QAT"
andrewor14 483be23
Update on "Add NVFP4 QAT"
andrewor14 ff675ee
Update base for Update on "Add NVFP4 QAT"
andrewor14 8b4092b
Update on "Add NVFP4 QAT"
andrewor14 5b41e42
Update base for Update on "Add NVFP4 QAT"
andrewor14 cec5159
Update on "Add NVFP4 QAT"
andrewor14 c881741
Update base for Update on "Add NVFP4 QAT"
andrewor14 1ff5de1
Update on "Add NVFP4 QAT"
andrewor14 362706d
Update base for Update on "Add NVFP4 QAT"
andrewor14 732fb16
Update on "Add NVFP4 QAT"
andrewor14 cda3a85
Update base for Update on "Add NVFP4 QAT"
andrewor14 80cc501
Update on "Add NVFP4 QAT"
andrewor14 bc2a059
Update base for Update on "Add NVFP4 QAT"
andrewor14 a024e29
Update on "Add NVFP4 QAT"
andrewor14 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# Temporary location for prototype QAT features that will | ||
# eventually live in torchao/quantization/qat | ||
|
||
from .nvfp4 import ( | ||
NVFP4FakeQuantizeConfig, | ||
NVFP4FakeQuantizer, | ||
) | ||
|
||
__all__ = [ | ||
"NVFP4FakeQuantizeConfig", | ||
"NVFP4FakeQuantizer", | ||
] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from dataclasses import dataclass | ||
|
||
import torch | ||
|
||
from torchao.prototype.mx_formats.nvfp4_tensor import ( | ||
_nvfp4_quantize, | ||
per_tensor_amax_to_scale, | ||
) | ||
from torchao.quantization.qat import ( | ||
FakeQuantizeConfigBase, | ||
FakeQuantizerBase, | ||
) | ||
|
||
|
||
@dataclass | ||
class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase): | ||
""" | ||
Config for fake quantizing weights or activations to NVIDIA's NVFP4 format | ||
according to https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/. | ||
|
||
Fake quantization numerics follow `NVFP4Tensor` closely: https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/nvfp4_tensor.py. | ||
|
||
Args: | ||
use_per_tensor_scale (bool): Whether to use two-level per-tensor fp32 scaling | ||
after the initial fp8 (e4m3) block-wise scaling (default True) | ||
""" | ||
|
||
use_per_tensor_scale: bool = True | ||
|
||
|
||
class NVFP4FakeQuantizer(FakeQuantizerBase): | ||
""" | ||
(Prototype) Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config. | ||
""" | ||
|
||
def __init__(self, config: NVFP4FakeQuantizeConfig): | ||
super().__init__() | ||
torch._C._log_api_usage_once("torchao.quantization.qat.NVFP4FakeQuantizer") | ||
self.config = config | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
block_size = 16 | ||
original_shape = x.shape | ||
if x.dim() == 3: | ||
x = x.view(-1, x.shape[-1]) | ||
if self.config.use_per_tensor_scale: | ||
tensor_amax = torch.max(torch.abs(x)) | ||
per_tensor_scale = per_tensor_amax_to_scale(tensor_amax) | ||
else: | ||
per_tensor_scale = None | ||
|
||
# quantize | ||
scale, q = _nvfp4_quantize( | ||
x, | ||
block_size=block_size, | ||
per_tensor_scale=per_tensor_scale, | ||
skip_dtype_cast_and_packing=True, | ||
) | ||
if self.config.use_per_tensor_scale: | ||
scale = scale * per_tensor_scale | ||
assert q.dtype == x.dtype | ||
assert scale.dtype == torch.float32 | ||
|
||
# dequantize | ||
M, K = q.shape[0], q.shape[1] | ||
q = q.view(M, K // block_size, block_size) | ||
scale = scale.view(M, K // block_size, 1) | ||
dq = q * scale | ||
return dq.view(original_shape).to(x.dtype) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.