Add NVFP4 QAT #2666

Status: Open. Wants to merge 1 commit into base branch gh/andrewor14/16/base.
1 change: 1 addition & 0 deletions docs/source/api_ref_qat.rst
@@ -26,6 +26,7 @@ Custom QAT APIs

FakeQuantizeConfigBase
IntxFakeQuantizeConfig
NVFP4FakeQuantizeConfig
FakeQuantizedLinear
FakeQuantizedEmbedding
FakeQuantizer
32 changes: 30 additions & 2 deletions test/quantization/test_qat.py
@@ -12,6 +12,7 @@
import warnings
from typing import List

import pytest
import torch
import torch.nn.functional as F
from parameterized import parameterized
@@ -44,6 +45,7 @@
)
from torchao.quantization.qat.fake_quantize_config import (
IntxFakeQuantizeConfig,
NVFP4FakeQuantizeConfig,
)
from torchao.quantization.qat.fake_quantizer import (
FakeQuantizer,
@@ -112,8 +114,8 @@ def __init__(self):
self.sub = Sub()
self.linear2 = torch.nn.Linear(256, 512, bias=False).to(torch.float)

def example_inputs(self):
return (torch.randn(1, 512).to(torch.float),)
def example_inputs(self, device: torch.device = None):
return (torch.randn((1, 512), device=device).to(torch.float),)

def _get_all_weight_qparams(self) -> List[torch.Tensor]:
return [
@@ -1884,6 +1886,32 @@ def test_qat_api_deprecation(self):
str(w.message),
)

@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
@pytest.mark.parametrize("use_per_tensor_scale", [True, False])
def test_qat_nvfp4(self, use_per_tensor_scale: bool = False):
"""
Test QAT with `NVFP4FakeQuantizeConfig`.
"""
torch.manual_seed(self.SEED)
m = M().cuda()
baseline_model = copy.deepcopy(m)
qat_config = QATConfig(
activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
step="prepare",
)
quantize_(m, qat_config)

# Compare prepared values
torch.manual_seed(self.SEED)
x = m.example_inputs("cuda")
out = m(*x)
baseline_out = baseline_model(*x)
sqnr = compute_error(out, baseline_out).item()
# Use same SQNR threshold as `test_nvfp4_reconstruction`
# TODO: why is this 0.0 when `use_per_tensor_scale=True`?
Contributor comment: That seems to be a bug; with `use_per_tensor_scale=True` the SQNR should be higher, probably supposed to be 10.0.

self.assertGreater(sqnr, 8.0)


if __name__ == "__main__":
unittest.main()
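
For reference, `compute_error` above returns an SQNR in decibels; a minimal sketch of the usual definition (an assumption about torchao's helper, not taken from this PR) is:

    import torch

    def sqnr_db(reference: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
        # Signal-to-quantization-noise ratio in dB:
        # 20 * log10(||reference|| / ||reference - candidate||)
        signal = torch.linalg.norm(reference)
        noise = torch.linalg.norm(reference - candidate)
        return 20 * torch.log10(signal / noise)

Under this definition, the `assertGreater(sqnr, 8.0)` check requires the fake-quantized output to retain roughly 8 dB of signal over quantization noise relative to the baseline model.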
32 changes: 23 additions & 9 deletions torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -764,6 +764,15 @@ def nvfp4_quantize(
AssertionError: If input dtype is not supported, tensor size is not
divisible by block_size, tensor is not contiguous, or block_size != 16
"""
return _nvfp4_quantize(data_hp, block_size, per_tensor_scale)


def _nvfp4_quantize(
data_hp: torch.Tensor,
block_size: int = 16,
per_tensor_scale: Optional[torch.Tensor] = None,
skip_dtype_cast_and_packing: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
assert data_hp.dtype in (torch.bfloat16, torch.float), (
f"{data_hp.dtype} not supported"
)
@@ -782,9 +791,9 @@ def nvfp4_quantize(
out_scales = None
if per_tensor_scale is None:
# We are doing single level scaling
block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX).to(
torch.float8_e4m3fn
)
block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX)
if not skip_dtype_cast_and_packing:
block_scale_fp8 = block_scale_fp8.to(torch.float8_e4m3fn)
block_scale_fp32 = block_scale_fp8.to(torch.float32)
data_scaled = data_hp / block_scale_fp32.unsqueeze(-1)
out_scales = block_scale_fp8
@@ -797,7 +806,9 @@
scaled_block_scales = block_scale_fp32 / per_tensor_scale
scaled_block_scales_fp8 = torch.clamp(
scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX
).to(torch.float8_e4m3fn)
)
if not skip_dtype_cast_and_packing:
scaled_block_scales_fp8 = scaled_block_scales_fp8.to(torch.float8_e4m3fn)
scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32)
# We "temporarily" dequant the scaled_block_scales_fp32 to get the per_tensor_scale
# To apply to data
@@ -807,8 +818,11 @@

data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
data_scaled = data_scaled.view(orig_shape)
data_lp = f32_to_f4_unpacked(data_scaled)
# TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
# data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
data_lp = pack_uint4(data_lp)
return out_scales, data_lp
if skip_dtype_cast_and_packing:
return out_scales, data_scaled
else:
data_lp = f32_to_f4_unpacked(data_scaled)
# TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
# data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
data_lp = pack_uint4(data_lp)
return out_scales, data_lp
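
The new `skip_dtype_cast_and_packing` flag is what the QAT fake quantizer relies on: instead of fp8 block scales and packed uint8 FP4 data, it returns fp32 block scales and the scaled, clamped values unpacked in high precision, so they can be dequantized with a simple block-wise multiply. A rough sketch of the two paths (CUDA tensor assumed, as in the test above; the input shape must be divisible by the block size):

    import torch
    from torchao.prototype.mx_formats.nvfp4_tensor import _nvfp4_quantize, nvfp4_quantize

    x = torch.randn(32, 64, dtype=torch.float32, device="cuda")

    # Default path: fp8 (e4m3) block scales and packed uint8 storage
    # (two FP4 values per byte).
    scales_fp8, packed = nvfp4_quantize(x, block_size=16)

    # New QAT path: scales stay in fp32 and the values stay unpacked in
    # high precision, so dequantization is just a per-block multiply.
    scales_fp32, unpacked = _nvfp4_quantize(
        x, block_size=16, skip_dtype_cast_and_packing=True
    )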
16 changes: 16 additions & 0 deletions torchao/quantization/qat/fake_quantize_config.py
@@ -36,6 +36,22 @@ class FakeQuantizeConfigBase(abc.ABC):
pass


@dataclass
class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
"""
Config for fake quantizing weights or activations to NVIDIA's NVFP4 format
according to https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/.

Fake quantization numerics follow `NVFP4Tensor` closely: https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/nvfp4_tensor.py.

Args:
use_per_tensor_scale (bool): Whether to use two-level per-tensor fp32 scaling
after the initial fp8 (e4m3) block-wise scaling.
"""

use_per_tensor_scale: bool = False
Contributor comment: We should default to True.



@dataclass
class IntxFakeQuantizeConfig(FakeQuantizeConfigBase):
"""
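
For context, a minimal sketch of how the new config plugs into the existing QAT prepare flow (mirroring the test above; the toy model and the top-level `QATConfig`/`quantize_` imports are assumptions about the current torchao API, not part of this diff):

    import torch
    from torchao.quantization import quantize_
    from torchao.quantization.qat import QATConfig
    from torchao.quantization.qat.fake_quantize_config import NVFP4FakeQuantizeConfig

    model = torch.nn.Sequential(torch.nn.Linear(512, 256)).cuda()

    # Fake quantize both activations and weights to NVFP4 during fine-tuning.
    qat_config = QATConfig(
        activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale=True),
        weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale=True),
        step="prepare",
    )
    quantize_(model, qat_config)

    # ... fine-tune `model` as usual; linear layers now see NVFP4 fake quantization.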
67 changes: 53 additions & 14 deletions torchao/quantization/qat/fake_quantizer.py
@@ -29,6 +29,7 @@
from .fake_quantize_config import (
FakeQuantizeConfigBase,
IntxFakeQuantizeConfig,
NVFP4FakeQuantizeConfig,
)
from .utils import (
_fake_quantize_per_channel_group,
@@ -46,13 +47,14 @@ def __init__(self, config: FakeQuantizeConfigBase):
super().__init__()
self.config = config
self.enabled = True
self.scale: Optional[torch.Tensor] = None
self.zero_point: Optional[torch.Tensor] = None

# For range learning only
# TODO: make this configurable?
self._scale_eps = 1e-9
self._initialized = False
if isinstance(self.config, IntxFakeQuantizeConfig):
self.scale: Optional[torch.Tensor] = None
self.zero_point: Optional[torch.Tensor] = None
# For range learning only
# TODO: make this configurable?
self._scale_eps = 1e-9
self._initialized = False

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
@@ -62,9 +64,46 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
if not self.enabled:
return x

if not isinstance(self.config, IntxFakeQuantizeConfig):
raise ValueError("Only IntxFakeQuantizeConfig is supported currently")
if isinstance(self.config, NVFP4FakeQuantizeConfig):
return self._nvfp4_forward(x)
elif isinstance(self.config, IntxFakeQuantizeConfig):
return self._intx_forward(x)
else:
raise ValueError(f"Unexpected config type {self.config}")

def _nvfp4_forward(self, x: torch.Tensor):
"""
Apply NVFP4 fake quantization to the tensor following `NVFP4Tensor`.
"""
from torchao.prototype.mx_formats.nvfp4_tensor import (
_nvfp4_quantize,
per_tensor_amax_to_scale,
)

block_size = 16
if self.config.use_per_tensor_scale:
tensor_amax = torch.max(torch.abs(x))
per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
else:
per_tensor_scale = None
scale, q = _nvfp4_quantize(
x,
block_size=block_size,
per_tensor_scale=per_tensor_scale,
skip_dtype_cast_and_packing=True,
)
assert q.dtype == x.dtype
assert scale.dtype == torch.float32
M, K = q.shape[0], q.shape[1]
q = q.view(M, K // block_size, block_size)
scale = scale.view(M, K // block_size, 1)
dq = q * scale
return dq.view(x.shape)

def _intx_forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Apply intx fake quantization to the tensor.
"""
if (
self.config.range_learning
and not self._initialized
@@ -77,15 +116,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)

if isinstance(self.config.granularity, PerToken):
return self._per_token_forward(x)
return self._intx_per_token_forward(x)
elif isinstance(self.config.granularity, (PerAxis, PerGroup)):
return self._per_channel_or_group_forward(x)
return self._intx_per_channel_or_group_forward(x)
else:
raise ValueError("Unknown granularity '%s'" % self.config.granularity)

def _per_token_forward(self, x: torch.Tensor) -> torch.Tensor:
def _intx_per_token_forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Perform per token fake quantization on the tensor.
Perform intx per token fake quantization on the tensor.
"""
if self.config.is_symmetric:
raise NotImplementedError("Symmetric per token is not supported yet")
@@ -105,9 +144,9 @@ def _per_token_forward(self, x: torch.Tensor) -> torch.Tensor:
self._maybe_update_qparams_for_range_learning()
return _fake_quantize_per_token(x, self.scale, self.zero_point, qmin, qmax)

def _per_channel_or_group_forward(self, x: torch.Tensor) -> torch.Tensor:
def _intx_per_channel_or_group_forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Perform per channel or per group fake quantization on the tensor.
Perform intx per channel or per group fake quantization on the tensor.
We express per channel using per group where the group size is the size
of the last dimension of the tensor.
"""
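
The reshaping in `_nvfp4_forward` is the core of the fake-quantize round trip: each row of K values is split into K / 16 blocks, every block is multiplied by its own fp32 scale, and the result is reshaped back to the input shape. A small self-contained illustration of that dequantize step (toy shapes, random stand-in values):

    import torch

    M, K, block_size = 4, 64, 16
    q = torch.randn(M, K)                    # stand-in for the scaled values
    scale = torch.rand(M, K // block_size)   # one scale per 16-value block

    # Multiply each block by its scale, then restore the original shape.
    dq = (q.view(M, K // block_size, block_size)
          * scale.view(M, K // block_size, 1)).view(M, K)
    assert dq.shape == (M, K)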
4 changes: 3 additions & 1 deletion torchao/quantization/qat/linear.py
@@ -90,7 +90,9 @@ def __init__(

# initialize weight fake quantizer
if weight_config is not None:
if isinstance(weight_config.granularity, PerGroup):
if isinstance(weight_config, IntxFakeQuantizeConfig) and isinstance(
weight_config.granularity, PerGroup
):
group_size = weight_config.group_size
if group_size is not None and in_features % group_size != 0:
raise ValueError(
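
The extra isinstance check in linear.py matters because the group-size validation only applies to intx configs; `NVFP4FakeQuantizeConfig` has no `granularity` field, so the old `isinstance(weight_config.granularity, PerGroup)` check would raise an AttributeError when an NVFP4 weight config is passed. A quick hypothetical illustration:

    from torchao.quantization.qat.fake_quantize_config import NVFP4FakeQuantizeConfig

    config = NVFP4FakeQuantizeConfig()
    # The NVFP4 config only carries `use_per_tensor_scale`, so any code that
    # unconditionally reads `config.granularity` would fail with AttributeError.
    assert not hasattr(config, "granularity")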