diff --git a/docs/source/api_ref_qat.rst b/docs/source/api_ref_qat.rst
index 73106939cd..8383f7b558 100644
--- a/docs/source/api_ref_qat.rst
+++ b/docs/source/api_ref_qat.rst
@@ -26,6 +26,7 @@ Custom QAT APIs
 
     FakeQuantizeConfigBase
     IntxFakeQuantizeConfig
+    NVFP4FakeQuantizeConfig
     FakeQuantizedLinear
     FakeQuantizedEmbedding
     FakeQuantizer
diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 48a9f780b6..a99056cd39 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -12,6 +12,7 @@
 import warnings
 from typing import List
 
+import pytest
 import torch
 import torch.nn.functional as F
 from parameterized import parameterized
@@ -44,6 +45,7 @@
 )
 from torchao.quantization.qat.fake_quantize_config import (
     IntxFakeQuantizeConfig,
+    NVFP4FakeQuantizeConfig,
 )
 from torchao.quantization.qat.fake_quantizer import (
     FakeQuantizer,
@@ -112,8 +114,8 @@ def __init__(self):
         self.sub = Sub()
         self.linear2 = torch.nn.Linear(256, 512, bias=False).to(torch.float)
 
-    def example_inputs(self):
-        return (torch.randn(1, 512).to(torch.float),)
+    def example_inputs(self, device: torch.device = None):
+        return (torch.randn((1, 512), device=device).to(torch.float),)
 
     def _get_all_weight_qparams(self) -> List[torch.Tensor]:
         return [
@@ -1884,6 +1886,32 @@ def test_qat_api_deprecation(self):
                 str(w.message),
             )
 
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @pytest.mark.parametrize("use_per_tensor_scale", [True, False])
+    def test_qat_nvfp4(self, use_per_tensor_scale: bool = False):
+        """
+        Test QAT with `NVFP4FakeQuantizeConfig`.
+        """
+        torch.manual_seed(self.SEED)
+        m = M().cuda()
+        baseline_model = copy.deepcopy(m)
+        qat_config = QATConfig(
+            activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
+            weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
+            step="prepare",
+        )
+        quantize_(m, qat_config)
+
+        # Compare prepared values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs("cuda")
+        out = m(*x)
+        baseline_out = baseline_model(*x)
+        sqnr = compute_error(out, baseline_out).item()
+        # Use same SQNR threshold as `test_nvfp4_reconstruction`
+        # TODO: why is this 0.0 when `use_per_tensor_scale=True`?
+        self.assertGreater(sqnr, 8.0)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py
index 221017b5f4..d3fb5becc7 100644
--- a/torchao/prototype/mx_formats/nvfp4_tensor.py
+++ b/torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -764,6 +764,15 @@ def nvfp4_quantize(
         AssertionError: If input dtype is not supported, tensor size is not
             divisible by block_size, tensor is not contiguous, or block_size != 16
     """
+    return _nvfp4_quantize(data_hp, block_size, per_tensor_scale)
+
+
+def _nvfp4_quantize(
+    data_hp: torch.Tensor,
+    block_size: int = 16,
+    per_tensor_scale: Optional[torch.Tensor] = None,
+    skip_dtype_cast_and_packing: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
     assert data_hp.dtype in (torch.bfloat16, torch.float), (
         f"{data_hp.dtype} not supported"
     )
@@ -782,9 +791,9 @@ def nvfp4_quantize(
     out_scales = None
     if per_tensor_scale is None:
         # We are doing single level scaling
-        block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX).to(
-            torch.float8_e4m3fn
-        )
+        block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX)
+        if not skip_dtype_cast_and_packing:
+            block_scale_fp8 = block_scale_fp8.to(torch.float8_e4m3fn)
         block_scale_fp32 = block_scale_fp8.to(torch.float32)
         data_scaled = data_hp / block_scale_fp32.unsqueeze(-1)
         out_scales = block_scale_fp8
@@ -797,7 +806,9 @@ def nvfp4_quantize(
         scaled_block_scales = block_scale_fp32 / per_tensor_scale
         scaled_block_scales_fp8 = torch.clamp(
             scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX
-        ).to(torch.float8_e4m3fn)
+        )
+        if not skip_dtype_cast_and_packing:
+            scaled_block_scales_fp8 = scaled_block_scales_fp8.to(torch.float8_e4m3fn)
         scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32)
         # We "temporarily" dequant the scaled_block_scales_fp32 to get the per_tensor_scale
         # To apply to data
@@ -807,8 +818,11 @@ def nvfp4_quantize(
     data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
     data_scaled = data_scaled.view(orig_shape)
 
-    data_lp = f32_to_f4_unpacked(data_scaled)
-    # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
-    # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
-    data_lp = pack_uint4(data_lp)
-    return out_scales, data_lp
+    if skip_dtype_cast_and_packing:
+        return out_scales, data_scaled
+    else:
+        data_lp = f32_to_f4_unpacked(data_scaled)
+        # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
+        # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
+        data_lp = pack_uint4(data_lp)
+        return out_scales, data_lp
diff --git a/torchao/quantization/qat/fake_quantize_config.py b/torchao/quantization/qat/fake_quantize_config.py
index 554ed2a065..d40d60b237 100644
--- a/torchao/quantization/qat/fake_quantize_config.py
+++ b/torchao/quantization/qat/fake_quantize_config.py
@@ -36,6 +36,22 @@ class FakeQuantizeConfigBase(abc.ABC):
     pass
 
 
+@dataclass
+class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
+    """
+    Config for fake quantizing weights or activations to NVIDIA's NVFP4 format
+    according to https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/.
+
+    Fake quantization numerics follow `NVFP4Tensor` closely: https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/nvfp4_tensor.py.
+
+    Args:
+        use_per_tensor_scale (bool): Whether to use two-level per-tensor fp32 scaling
+            after the initial fp8 (e4m3) block-wise scaling.
+    """
+
+    use_per_tensor_scale: bool = False
+
+
 @dataclass
 class IntxFakeQuantizeConfig(FakeQuantizeConfigBase):
     """
diff --git a/torchao/quantization/qat/fake_quantizer.py b/torchao/quantization/qat/fake_quantizer.py
index 3cb873f3ff..4d80ae5443 100644
--- a/torchao/quantization/qat/fake_quantizer.py
+++ b/torchao/quantization/qat/fake_quantizer.py
@@ -29,6 +29,7 @@
 from .fake_quantize_config import (
     FakeQuantizeConfigBase,
     IntxFakeQuantizeConfig,
+    NVFP4FakeQuantizeConfig,
 )
 from .utils import (
     _fake_quantize_per_channel_group,
@@ -46,13 +47,14 @@ def __init__(self, config: FakeQuantizeConfigBase):
         super().__init__()
         self.config = config
         self.enabled = True
-        self.scale: Optional[torch.Tensor] = None
-        self.zero_point: Optional[torch.Tensor] = None
-        # For range learning only
-        # TODO: make this configurable?
-        self._scale_eps = 1e-9
-        self._initialized = False
+        if isinstance(self.config, IntxFakeQuantizeConfig):
+            self.scale: Optional[torch.Tensor] = None
+            self.zero_point: Optional[torch.Tensor] = None
+            # For range learning only
+            # TODO: make this configurable?
+            self._scale_eps = 1e-9
+            self._initialized = False
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -62,9 +64,46 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         if not self.enabled:
             return x
 
-        if not isinstance(self.config, IntxFakeQuantizeConfig):
-            raise ValueError("Only IntxFakeQuantizeConfig is supported currently")
+        if isinstance(self.config, NVFP4FakeQuantizeConfig):
+            return self._nvfp4_forward(x)
+        elif isinstance(self.config, IntxFakeQuantizeConfig):
+            return self._intx_forward(x)
+        else:
+            raise ValueError(f"Unexpected config type {self.config}")
+
+    def _nvfp4_forward(self, x: torch.Tensor):
+        """
+        Apply NVFP4 fake quantization to the tensor following `NVFP4Tensor`.
+        """
+        from torchao.prototype.mx_formats.nvfp4_tensor import (
+            _nvfp4_quantize,
+            per_tensor_amax_to_scale,
+        )
+
+        block_size = 16
+        if self.config.use_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(x))
+            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+        else:
+            per_tensor_scale = None
+        scale, q = _nvfp4_quantize(
+            x,
+            block_size=block_size,
+            per_tensor_scale=per_tensor_scale,
+            skip_dtype_cast_and_packing=True,
+        )
+        assert q.dtype == x.dtype
+        assert scale.dtype == torch.float32
+        M, K = q.shape[0], q.shape[1]
+        q = q.view(M, K // block_size, block_size)
+        scale = scale.view(M, K // block_size, 1)
+        dq = q * scale
+        return dq.view(x.shape)
+
+    def _intx_forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Apply intx fake quantization to the tensor.
+        """
         if (
             self.config.range_learning
             and not self._initialized
@@ -77,15 +116,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
 
         if isinstance(self.config.granularity, PerToken):
-            return self._per_token_forward(x)
+            return self._intx_per_token_forward(x)
         elif isinstance(self.config.granularity, (PerAxis, PerGroup)):
-            return self._per_channel_or_group_forward(x)
+            return self._intx_per_channel_or_group_forward(x)
         else:
             raise ValueError("Unknown granularity '%s'" % self.config.granularity)
 
-    def _per_token_forward(self, x: torch.Tensor) -> torch.Tensor:
+    def _intx_per_token_forward(self, x: torch.Tensor) -> torch.Tensor:
         """
-        Perform per token fake quantization on the tensor.
+        Perform intx per token fake quantization on the tensor.
         """
         if self.config.is_symmetric:
             raise NotImplementedError("Symmetric per token is not supported yet")
@@ -105,9 +144,9 @@ def _per_token_forward(self, x: torch.Tensor) -> torch.Tensor:
             self._maybe_update_qparams_for_range_learning()
         return _fake_quantize_per_token(x, self.scale, self.zero_point, qmin, qmax)
 
-    def _per_channel_or_group_forward(self, x: torch.Tensor) -> torch.Tensor:
+    def _intx_per_channel_or_group_forward(self, x: torch.Tensor) -> torch.Tensor:
         """
-        Perform per channel or per group fake quantization on the tensor.
+        Perform intx per channel or per group fake quantization on the tensor.
         We express per channel using per group where the group size is the size
         of the last dimension of the tensor.
         """
diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py
index c9c8f8ea5d..6773d02784 100644
--- a/torchao/quantization/qat/linear.py
+++ b/torchao/quantization/qat/linear.py
@@ -90,7 +90,9 @@ def __init__(
 
         # initialize weight fake quantizer
        if weight_config is not None:
-            if isinstance(weight_config.granularity, PerGroup):
+            if isinstance(weight_config, IntxFakeQuantizeConfig) and isinstance(
+                weight_config.granularity, PerGroup
+            ):
                 group_size = weight_config.group_size
                 if group_size is not None and in_features % group_size != 0:
                     raise ValueError(