Add NVFP4 QAT #2666
Changes from 46 commits
@@ -34,6 +34,7 @@
     FakeQuantizeConfigBase,
     Float8FakeQuantizeConfig,
     IntxFakeQuantizeConfig,
+    NVFP4FakeQuantizeConfig,
 )
 from .utils import (
     _fake_quantize_per_channel_group,
@@ -59,8 +60,10 @@ def __repr__(self) -> str:
     def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase":
         if isinstance(config, IntxFakeQuantizeConfig):
             return IntxFakeQuantizer(config)
-        if isinstance(config, Float8FakeQuantizeConfig):
+        elif isinstance(config, Float8FakeQuantizeConfig):
             return Float8FakeQuantizer(config)
+        elif isinstance(config, NVFP4FakeQuantizeConfig):
+            return NVFP4FakeQuantizer(config)
         else:
             raise ValueError(f"Unknown config type: {config}")
 
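With this dispatch in place, constructing a fake quantizer from the new config is a one-liner. A minimal sketch, assuming `FakeQuantizerBase` and `NVFP4FakeQuantizeConfig` are re-exported from `torchao.quantization.qat` and that `use_per_tensor_scale` is a constructor argument of the config:

```python
import torch
from torchao.quantization.qat import FakeQuantizerBase, NVFP4FakeQuantizeConfig

# Build the config and let from_config pick the matching quantizer class.
config = NVFP4FakeQuantizeConfig(use_per_tensor_scale=True)
fake_quantizer = FakeQuantizerBase.from_config(config)
print(type(fake_quantizer).__name__)  # NVFP4FakeQuantizer

# Fake quantization is shape- and dtype-preserving (see forward() below),
# so it can wrap a bf16 weight during QAT without touching the rest of
# the training loop.
weight = torch.randn(128, 256, dtype=torch.bfloat16)
out = fake_quantizer(weight)
assert out.shape == weight.shape and out.dtype == weight.dtype
```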
@@ -73,6 +76,7 @@ class Float8FakeQuantizer(FakeQuantizerBase):
     def __init__(self, config: Float8FakeQuantizeConfig):
         super().__init__()
         self.config = config
+        torch._C._log_api_usage_once("torchao.quantization.qat.Float8FakeQuantizer")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         original_dtype = x.dtype
@@ -91,14 +95,60 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return dq
 
 
+class NVFP4FakeQuantizer(FakeQuantizerBase):
+    """
+    Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config.
+    """
+
+    def __init__(self, config: NVFP4FakeQuantizeConfig):
+        super().__init__()
+        torch._C._log_api_usage_once("torchao.quantization.qat.NVFP4FakeQuantizer")
+        self.config = config
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        from torchao.prototype.mx_formats.nvfp4_tensor import (
+            _nvfp4_quantize,
+            per_tensor_amax_to_scale,
+        )
+
+        block_size = 16
+        original_shape = x.shape
+        if x.dim() == 3:
+            x = x.view(-1, x.shape[-1])
+
+        if self.config.use_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(x))
+            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+        else:
+            per_tensor_scale = None
+
+        # quantize
+        scale, q = _nvfp4_quantize(
+            x,
+            block_size=block_size,
+            per_tensor_scale=per_tensor_scale,
+            skip_dtype_cast_and_packing=True,
+        )
+        if self.config.use_per_tensor_scale:
+            scale = scale * per_tensor_scale
+        assert q.dtype == x.dtype
+        assert scale.dtype == torch.float32
+
+        # dequantize
+        M, K = q.shape[0], q.shape[1]
+        q = q.view(M, K // block_size, block_size)
+        scale = scale.view(M, K // block_size, 1)
+        dq = q * scale
+        return dq.view(original_shape).to(x.dtype)
+
+
 class IntxFakeQuantizer(FakeQuantizerBase):
     """
     Generic module for applying integer fake quantization to a tensor, as specified in the config.
     """
 
     def __init__(self, config: IntxFakeQuantizeConfig):
         super().__init__()
-        torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizer")
+        torch._C._log_api_usage_once("torchao.quantization.qat.IntxFakeQuantizer")
         self.config = config
         self.enabled = True
         self.scale: Optional[torch.Tensor] = None
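The `scale = scale * per_tensor_scale` step in `forward` recombines NVFP4's two scaling levels: `_nvfp4_quantize` computes the per-block scales relative to the per-tensor scale, so the effective dequantization scale is their product. A self-contained numeric sketch of that recombination; the constants (448.0 for the fp8 e4m3 max, 6.0 for the fp4 e2m1 max) are an assumption about what `per_tensor_amax_to_scale` uses, mirroring the usual NVFP4 recipe:

```python
import torch

x = torch.randn(4, 64)

# Assumed per-tensor scale recipe: amax / (e4m3_max * e2m1_max).
tensor_amax = x.abs().max()
per_tensor_scale = tensor_amax / (448.0 * 6.0)

# One 16-element block: its scale is computed relative to per_tensor_scale...
block = x[0, :16]
block_scale = block.abs().max() / (6.0 * per_tensor_scale)

# ...so multiplying the two back together (as forward() does) yields the
# effective scale that maps the block into the fp4 e2m1 range [-6, 6].
effective_scale = block_scale * per_tensor_scale
q = block / effective_scale
assert q.abs().max() <= 6.0 + 1e-4
```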
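The dequantize block at the end of `forward` is just a blockwise broadcast: each row is grouped into `K // 16` blocks, and one fp32 scale is applied per block. A plain-tensor illustration of those reshapes, with stand-in values and no torchao calls:

```python
import torch

block_size = 16
M, K = 4, 64

q = torch.randn(M, K)                         # stand-in for the fake-quantized values
scale = torch.rand(M, K // block_size) + 0.5  # one fp32 scale per 16-wide block

# Same reshapes as forward(): (M, K) -> (M, K//16, 16) and (M, K//16) ->
# (M, K//16, 1), so the multiply broadcasts one scale over each block.
dq = q.view(M, K // block_size, block_size) * scale.view(M, K // block_size, 1)
dq = dq.view(M, K)
assert dq.shape == (M, K)
```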