Commit 28b4136

Add NVFP4 QAT
**Summary:** This commit adds a QAT flow for NVFP4, following the numerics in `NVFP4Tensor` closely but without the dtype casting, swizzling, and the packing/unpacking. Users can call this flow as follows:

```
from torchao.quantization import quantize_
from torchao.quantization.qat import NVFP4FakeQuantizeConfig, QATConfig

qat_config = QATConfig(
    weight_config=NVFP4FakeQuantizeConfig(),
    step="prepare",
)
quantize_(model, qat_config)
```

**Test Plan:**

```
python test/quantization/test_qat.py -k test_qat_nvfp4
```

ghstack-source-id: 5548756
Pull Request resolved: #2666
1 parent 5c0d6a3 commit 28b4136

7 files changed: +127 -15 lines changed


docs/source/api_ref_qat.rst

Lines changed: 1 addition & 0 deletions
```diff
@@ -27,6 +27,7 @@ Custom QAT APIs
     FakeQuantizeConfigBase
     IntxFakeQuantizeConfig
     Float8FakeQuantizeConfig
+    NVFP4FakeQuantizeConfig
     FakeQuantizedLinear
     FakeQuantizedEmbedding
     FakeQuantizerBase
```

test/quantization/test_qat.py

Lines changed: 26 additions & 2 deletions
```diff
@@ -50,6 +50,7 @@
 from torchao.quantization.qat.fake_quantize_config import (
     Float8FakeQuantizeConfig,
     IntxFakeQuantizeConfig,
+    NVFP4FakeQuantizeConfig,
 )
 from torchao.quantization.qat.fake_quantizer import (
     Float8FakeQuantizer,
@@ -118,8 +119,8 @@ def __init__(self):
         self.sub = Sub()
         self.linear2 = torch.nn.Linear(256, 512, bias=False).to(torch.float)
 
-    def example_inputs(self):
-        return (torch.randn(1, 512).to(torch.float),)
+    def example_inputs(self, device: torch.device = None):
+        return (torch.randn((1, 512), device=device).to(torch.float),)
 
     def _get_all_weight_scales(self) -> List[torch.Tensor]:
         return [
@@ -1932,6 +1933,29 @@ def test_quantize_api_fp8_int4(self):
             target_convert_sqnr=float("inf"),
         )
 
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @parametrize("use_per_tensor_scale", [True, False])
+    def test_qat_nvfp4(self, use_per_tensor_scale: bool):
+        """
+        Test QAT with `NVFP4FakeQuantizeConfig`.
+        """
+        torch.manual_seed(self.SEED)
+        m = M().cuda()
+        baseline_model = copy.deepcopy(m)
+        qat_config = QATConfig(
+            weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
+            step="prepare",
+        )
+        quantize_(m, qat_config)
+
+        # Compare prepared values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs("cuda")
+        out = m(*x)
+        baseline_out = baseline_model(*x)
+        sqnr = compute_error(out, baseline_out).item()
+        self.assertGreater(sqnr, 100)
+
 
 instantiate_parametrized_tests(TestQAT)
```
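The new test asserts an SQNR above 100 dB between the prepared (fake-quantized) model and the baseline, which is plausible here because this QAT path skips the low-precision dtype casts and packing. For reference, a minimal sketch of the signal-to-quantization-noise ratio in dB; the helper below is illustrative and only assumed to match what `compute_error` reports:

```python
import torch

def sqnr_db(reference: torch.Tensor, candidate: torch.Tensor) -> float:
    # 20 * log10(||signal|| / ||noise||): higher means the two tensors are closer
    signal = torch.linalg.norm(reference.float())
    noise = torch.linalg.norm(reference.float() - candidate.float())
    return (20 * torch.log10(signal / noise)).item()

# mirrors `self.assertGreater(sqnr, 100)` in test_qat_nvfp4:
# assert sqnr_db(baseline_out, out) > 100
```
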

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 25 additions & 10 deletions
```diff
@@ -310,10 +310,10 @@ def nvfp4_to_copy(func, types, args, kwargs):
 
     if dtype is not None:
         res = NVFP4Tensor(
+            tensor.qdata,
             tensor._scale_e4m3,
             tensor._per_tensor_scale,
             tensor._act_per_tensor_scale,
-            tensor._data,
             tensor._block_size,
             dtype,
             tensor._is_swizzled_scales,
@@ -749,13 +749,23 @@ def nvfp4_quantize(
         AssertionError: If input dtype is not supported, tensor size is not
             divisible by block_size, tensor is not contiguous, or block_size != 16
     """
+    return _nvfp4_quantize(data_hp, block_size, per_tensor_scale)
+
+
+def _nvfp4_quantize(
+    data_hp: torch.Tensor,
+    block_size: int = 16,
+    per_tensor_scale: Optional[torch.Tensor] = None,
+    skip_dtype_cast_and_packing: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
     assert data_hp.dtype in (torch.bfloat16, torch.float), (
         f"{data_hp.dtype} not supported"
     )
     assert data_hp.size(-1) % block_size == 0, "K dim must be divisible by block_size"
     assert data_hp.is_contiguous(), "Only support contiguous data for now"
     assert block_size == 16, "NVFP4 requires block_size=16"
 
+    orig_dtype = data_hp.dtype
     orig_shape = data_hp.shape
     # Convert to float32 early for consistent precision with Triton implementation
     data_hp = data_hp.float().reshape(orig_shape[0], -1, block_size)
@@ -767,9 +777,9 @@
     out_scales = None
     if per_tensor_scale is None:
         # We are doing single level scaling
-        block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX).to(
-            torch.float8_e4m3fn
-        )
+        block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX)
+        if not skip_dtype_cast_and_packing:
+            block_scale_fp8 = block_scale_fp8.to(torch.float8_e4m3fn)
         block_scale_fp32 = block_scale_fp8.to(torch.float32)
         data_scaled = data_hp / block_scale_fp32.unsqueeze(-1)
         out_scales = block_scale_fp8
@@ -782,7 +792,9 @@
         scaled_block_scales = block_scale_fp32 / per_tensor_scale
         scaled_block_scales_fp8 = torch.clamp(
             scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX
-        ).to(torch.float8_e4m3fn)
+        )
+        if not skip_dtype_cast_and_packing:
+            scaled_block_scales_fp8 = scaled_block_scales_fp8.to(torch.float8_e4m3fn)
         scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32)
         # We "temporarily" dequant the scaled_block_scales_fp32 to get the per_tensor_scale
         # To apply to data
@@ -792,8 +804,11 @@
 
     data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
     data_scaled = data_scaled.view(orig_shape)
-    data_lp = f32_to_f4_unpacked(data_scaled)
-    # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
-    # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
-    data_lp = pack_uint4(data_lp)
-    return out_scales, data_lp
+    if skip_dtype_cast_and_packing:
+        return out_scales, data_scaled.to(orig_dtype)
+    else:
+        data_lp = f32_to_f4_unpacked(data_scaled)
+        # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
+        # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
+        data_lp = pack_uint4(data_lp)
+        return out_scales, data_lp
```
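
For intuition: the QAT flow calls `_nvfp4_quantize(..., skip_dtype_cast_and_packing=True)` and then dequantizes, so values are scaled per 16-element block and clamped to the fp4 e2m1 range, but the scales are never cast to fp8 and the data is never rounded or packed to fp4. The sketch below illustrates that single-level path (`per_tensor_scale=None`); the block-scale formula and constants are assumptions for illustration, not the torchao implementation:

```python
import torch

F4_E2M1_MAX = 6.0                                   # assumed max magnitude of fp4 e2m1
F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max   # 448.0
E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny    # smallest normal e4m3 value

def fake_quantize_nvfp4_single_level(x: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    orig_shape = x.shape
    blocks = x.float().reshape(orig_shape[0], -1, block_size)
    # per-block scale chosen so the block amax maps onto the fp4 range
    block_amax = blocks.abs().amax(dim=-1, keepdim=True)
    block_scale = torch.clamp(block_amax / F4_E2M1_MAX, min=E4M3_EPS, max=F8E4M3_MAX)
    # "quantize": scale and clamp, skipping the cast to fp8 scales / packed fp4 data
    scaled = torch.clamp(blocks / block_scale, -F4_E2M1_MAX, F4_E2M1_MAX)
    # "dequantize": multiply back by the same block scales
    return (scaled * block_scale).reshape(orig_shape).to(x.dtype)

w = torch.randn(4, 64, dtype=torch.bfloat16)
print(fake_quantize_nvfp4_single_level(w).shape)  # torch.Size([4, 64])
```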

torchao/quantization/qat/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -17,12 +17,14 @@
     FakeQuantizeConfigBase,
     Float8FakeQuantizeConfig,
     IntxFakeQuantizeConfig,
+    NVFP4FakeQuantizeConfig,
 )
 from .fake_quantizer import (
     FakeQuantizer,
     FakeQuantizerBase,
     Float8FakeQuantizer,
     IntxFakeQuantizer,
+    NVFP4FakeQuantizer,
 )
 from .linear import (
     FakeQuantizedLinear,
@@ -40,6 +42,8 @@
     "Float8FakeQuantizer",
     "IntxFakeQuantizeConfig",
     "IntxFakeQuantizer",
+    "NVFP4FakeQuantizeConfig",
+    "NVFP4FakeQuantizer",
     "FakeQuantizedLinear",
     "FakeQuantizedEmbedding",
     # Prototype
```

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -77,6 +77,22 @@ def __post_init__(self):
             )
 
 
+@dataclass
+class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
+    """
+    Config for fake quantizing weights or activations to NVIDIA's NVFP4 format
+    according to https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/.
+
+    Fake quantization numerics follow `NVFP4Tensor` closely: https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/nvfp4_tensor.py.
+
+    Args:
+        use_per_tensor_scale (bool): Whether to use two-level per-tensor fp32 scaling
+            after the initial fp8 (e4m3) block-wise scaling (default True)
+    """
+
+    use_per_tensor_scale: bool = True
+
+
 @dataclass
 class IntxFakeQuantizeConfig(FakeQuantizeConfigBase):
     """
```

torchao/quantization/qat/fake_quantizer.py

Lines changed: 52 additions & 2 deletions
```diff
@@ -34,6 +34,7 @@
     FakeQuantizeConfigBase,
     Float8FakeQuantizeConfig,
     IntxFakeQuantizeConfig,
+    NVFP4FakeQuantizeConfig,
 )
 from .utils import (
     _fake_quantize_per_channel_group,
@@ -59,8 +60,10 @@ def __repr__(self) -> str:
     def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase":
         if isinstance(config, IntxFakeQuantizeConfig):
             return IntxFakeQuantizer(config)
-        if isinstance(config, Float8FakeQuantizeConfig):
+        elif isinstance(config, Float8FakeQuantizeConfig):
             return Float8FakeQuantizer(config)
+        elif isinstance(config, NVFP4FakeQuantizeConfig):
+            return NVFP4FakeQuantizer(config)
         else:
             raise ValueError(f"Unknown config type: {config}")
 
@@ -73,6 +76,7 @@ class Float8FakeQuantizer(FakeQuantizerBase):
     def __init__(self, config: Float8FakeQuantizeConfig):
         super().__init__()
         self.config = config
+        torch._C._log_api_usage_once("torchao.quantization.qat.Float8FakeQuantizer")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         original_dtype = x.dtype
@@ -91,14 +95,60 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return dq
 
 
+class NVFP4FakeQuantizer(FakeQuantizerBase):
+    """
+    Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config.
+    """
+
+    def __init__(self, config: NVFP4FakeQuantizeConfig):
+        super().__init__()
+        torch._C._log_api_usage_once("torchao.quantization.qat.NVFP4FakeQuantizer")
+        self.config = config
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        from torchao.prototype.mx_formats.nvfp4_tensor import (
+            _nvfp4_quantize,
+            per_tensor_amax_to_scale,
+        )
+
+        block_size = 16
+        original_shape = x.shape
+        if x.dim() == 3:
+            x = x.view(-1, x.shape[-1])
+        if self.config.use_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(x))
+            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+        else:
+            per_tensor_scale = None
+
+        # quantize
+        scale, q = _nvfp4_quantize(
+            x,
+            block_size=block_size,
+            per_tensor_scale=per_tensor_scale,
+            skip_dtype_cast_and_packing=True,
+        )
+        if self.config.use_per_tensor_scale:
+            scale = scale * per_tensor_scale
+        assert q.dtype == x.dtype
+        assert scale.dtype == torch.float32
+
+        # dequantize
+        M, K = q.shape[0], q.shape[1]
+        q = q.view(M, K // block_size, block_size)
+        scale = scale.view(M, K // block_size, 1)
+        dq = q * scale
+        return dq.view(original_shape).to(x.dtype)
+
+
 class IntxFakeQuantizer(FakeQuantizerBase):
     """
     Generic module for applying integer fake quantization to a tensor, as specified in the config.
     """
 
     def __init__(self, config: IntxFakeQuantizeConfig):
         super().__init__()
-        torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizer")
+        torch._C._log_api_usage_once("torchao.quantization.qat.IntxFakeQuantizer")
         self.config = config
         self.enabled = True
         self.scale: Optional[torch.Tensor] = None
```
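
The new fake quantizer can also be applied directly to a tensor, for example to inspect fake-quantized weight values outside a full QAT flow. A brief sketch (shapes and device are illustrative; the last dimension must be divisible by the NVFP4 block size of 16, and the input should be float32 or bfloat16):

```python
import torch
from torchao.quantization.qat import NVFP4FakeQuantizeConfig, NVFP4FakeQuantizer

fake_quantizer = NVFP4FakeQuantizer(NVFP4FakeQuantizeConfig(use_per_tensor_scale=True))
weight = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
weight_fq = fake_quantizer(weight)  # same shape and dtype, NVFP4 fake-quantized values
assert weight_fq.shape == weight.shape and weight_fq.dtype == weight.dtype
```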

torchao/quantization/qat/linear.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -92,7 +92,9 @@ def __init__(
 
         # initialize weight fake quantizer
         if weight_config is not None:
-            if isinstance(weight_config.granularity, PerGroup):
+            if isinstance(weight_config, IntxFakeQuantizeConfig) and isinstance(
+                weight_config.granularity, PerGroup
+            ):
                 group_size = weight_config.group_size
                 if group_size is not None and in_features % group_size != 0:
                     raise ValueError(
```
