
Commit e320c79

Add NVFP4 QAT
**Summary:** This commit adds a QAT flow for NVFP4, following the numerics in `NVFP4Tensor` closely but without the dtype casting, swizzling, and packing/unpacking. Users can call this flow as follows:

```
from torchao.quantization import quantize_
from torchao.quantization.qat import NVFP4FakeQuantizeConfig, QATConfig

qat_config = QATConfig(
    weight_config=NVFP4FakeQuantizeConfig(),
    step="prepare",
)
quantize_(model, qat_config)
```

**Test Plan:**

```
python test/quantization/test_qat.py -k test_qat_nvfp4
```

Initial benchmarks from fine-tuning Qwen3-1.7B on oasst1 for 3 epochs:

```
# Without QAT
| Tasks  |Version|Filter|n-shot| Metric        |   | Value |   |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  | 0.7927|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  | 1.7323|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |18.8815|±  |   N/A|

# With QAT
| Tasks  |Version|Filter|n-shot| Metric        |   | Value |   |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  | 0.7921|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  | 1.7316|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |18.8409|±  |   N/A|
```

ghstack-source-id: 5548756
Pull Request resolved: #2666
1 parent 2fd06de commit e320c79
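
For context, the prepare/convert pattern exercised by the new `test_quantize_api_nvfp4` test looks roughly like the sketch below. `MyModel` is a placeholder, and the sketch assumes linear layers whose inner dimension is divisible by the NVFP4 block size of 16.

```python
# Rough sketch of the end-to-end flow (prepare -> fine-tune -> convert).
# `MyModel` is a placeholder for any model containing nn.Linear layers.
from torchao.prototype.mx_formats import NVFP4InferenceConfig
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig

model = MyModel().cuda()

# Prepare: the PTQ config is translated into NVFP4 fake quantization
# (see _infer_fake_quantize_configs in the diff below)
quantize_(model, QATConfig(NVFP4InferenceConfig(), step="prepare"))

# ... fine-tune the model as usual ...

# Convert: replace fake quantization with real NVFP4 quantization
quantize_(model, QATConfig(NVFP4InferenceConfig(), step="convert"))
```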

8 files changed, +153 -12 lines changed


docs/source/api_ref_qat.rst

Lines changed: 1 addition & 0 deletions
```diff
@@ -27,6 +27,7 @@ Custom QAT APIs
     FakeQuantizeConfigBase
     IntxFakeQuantizeConfig
     Float8FakeQuantizeConfig
+    NVFP4FakeQuantizeConfig
     FakeQuantizedLinear
     FakeQuantizedEmbedding
     FakeQuantizerBase
```

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -527,7 +527,7 @@ def test_nvfp4_matmul_with_amax(
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
+    torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
 )
 def test_nvfp4_to_copy():
     from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
```

test/quantization/test_qat.py

Lines changed: 43 additions & 3 deletions
```diff
@@ -50,6 +50,7 @@
 from torchao.quantization.qat.fake_quantize_config import (
     Float8FakeQuantizeConfig,
     IntxFakeQuantizeConfig,
+    NVFP4FakeQuantizeConfig,
 )
 from torchao.quantization.qat.fake_quantizer import (
     Float8FakeQuantizer,
@@ -118,8 +119,8 @@ def __init__(self):
         self.sub = Sub()
         self.linear2 = torch.nn.Linear(256, 512, bias=False).to(torch.float)
 
-    def example_inputs(self):
-        return (torch.randn(1, 512).to(torch.float),)
+    def example_inputs(self, device: torch.device = None):
+        return (torch.randn((1, 512), device=device).to(torch.float),)
 
     def _get_all_weight_scales(self) -> List[torch.Tensor]:
         return [
@@ -1928,7 +1929,7 @@ def test_quantize_api_fp8_int4(self):
         """
         self._test_quantize_api_against_ptq(
             Float8DynamicActivationInt4WeightConfig(),
-            target_prepare_sqnr=15,
+            target_prepare_sqnr=12,
             target_convert_sqnr=float("inf"),
         )
 
@@ -1952,6 +1953,45 @@ def test_infer_fp8_int4_config(self):
         self.assertEqual(weight_config.group_size, 128)
         self.assertTrue(weight_config.is_symmetric)
 
+    @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
+    def test_quantize_api_nvfp4(self):
+        """
+        Test the following:
+            quantize_(model, QATConfig(NVFP4InferenceConfig(), step="prepare"))
+            quantize_(model, QATConfig(NVFP4InferenceConfig(), step="convert"))
+        """
+        from torchao.prototype.mx_formats import NVFP4InferenceConfig
+
+        self._test_quantize_api_against_ptq(
+            NVFP4InferenceConfig(),
+            target_prepare_sqnr=8,
+            target_convert_sqnr=float("inf"),
+        )
+
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @parametrize("use_per_tensor_scale", [True, False])
+    def test_qat_nvfp4(self, use_per_tensor_scale: bool):
+        """
+        Test QAT with `NVFP4FakeQuantizeConfig`.
+        """
+        torch.manual_seed(self.SEED)
+        m = M().cuda()
+        baseline_model = copy.deepcopy(m)
+        qat_config = QATConfig(
+            activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
+            weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
+            step="prepare",
+        )
+        quantize_(m, qat_config)
+
+        # Compare prepared values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs("cuda")
+        out = m(*x)
+        baseline_out = baseline_model(*x)
+        sqnr = compute_error(out, baseline_out).item()
+        self.assertGreater(sqnr, 30)
+
 
 instantiate_parametrized_tests(TestQAT)
```

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 18 additions & 5 deletions
```diff
@@ -749,13 +749,23 @@ def nvfp4_quantize(
         AssertionError: If input dtype is not supported, tensor size is not
             divisible by block_size, tensor is not contiguous, or block_size != 16
     """
+    return _nvfp4_quantize(data_hp, block_size, per_tensor_scale)
+
+
+def _nvfp4_quantize(
+    data_hp: torch.Tensor,
+    block_size: int = 16,
+    per_tensor_scale: Optional[torch.Tensor] = None,
+    skip_dtype_cast_and_packing: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
     assert data_hp.dtype in (torch.bfloat16, torch.float), (
         f"{data_hp.dtype} not supported"
     )
     assert data_hp.size(-1) % block_size == 0, "K dim must be divisible by block_size"
     assert data_hp.is_contiguous(), "Only support contiguous data for now"
     assert block_size == 16, "NVFP4 requires block_size=16"
 
+    orig_dtype = data_hp.dtype
     orig_shape = data_hp.shape
     # Convert to float32 early for consistent precision with Triton implementation
     data_hp = data_hp.float().reshape(orig_shape[0], -1, block_size)
@@ -792,8 +802,11 @@
 
     data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
     data_scaled = data_scaled.view(orig_shape)
-    data_lp = f32_to_f4_unpacked(data_scaled)
-    # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
-    # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
-    data_lp = pack_uint4(data_lp)
-    return out_scales, data_lp
+    if skip_dtype_cast_and_packing:
+        return out_scales.to(torch.float32), data_scaled.to(orig_dtype)
+    else:
+        data_lp = f32_to_f4_unpacked(data_scaled)
+        # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
+        # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
+        data_lp = pack_uint4(data_lp)
+        return out_scales, data_lp
```
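
For illustration only, a minimal sketch of how the new `skip_dtype_cast_and_packing` flag changes the return values of `_nvfp4_quantize`; the tensor shape below is an arbitrary choice, not a requirement from the diff.

```python
# Hypothetical usage of _nvfp4_quantize with the new flag; not part of this commit.
import torch
from torchao.prototype.mx_formats.nvfp4_tensor import _nvfp4_quantize

x = torch.randn(128, 64, dtype=torch.bfloat16)  # last dim divisible by block_size=16

# The default call, _nvfp4_quantize(x, block_size=16), returns block scales and a
# uint4-packed payload. The QAT path below instead keeps the fake-quantized values
# in the original dtype and returns fp32 scales.
scales_fp32, fake_q = _nvfp4_quantize(
    x, block_size=16, skip_dtype_cast_and_packing=True
)
assert fake_q.dtype == x.dtype and fake_q.shape == x.shape
assert scales_fp32.dtype == torch.float32
```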

torchao/quantization/qat/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -17,12 +17,14 @@
     FakeQuantizeConfigBase,
     Float8FakeQuantizeConfig,
     IntxFakeQuantizeConfig,
+    NVFP4FakeQuantizeConfig,
 )
 from .fake_quantizer import (
     FakeQuantizer,
     FakeQuantizerBase,
     Float8FakeQuantizer,
     IntxFakeQuantizer,
+    NVFP4FakeQuantizer,
 )
 from .linear import (
     FakeQuantizedLinear,
@@ -40,6 +42,8 @@
     "Float8FakeQuantizer",
     "IntxFakeQuantizeConfig",
     "IntxFakeQuantizer",
+    "NVFP4FakeQuantizeConfig",
+    "NVFP4FakeQuantizer",
     "FakeQuantizedLinear",
     "FakeQuantizedEmbedding",
     # Prototype
```

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 31 additions & 0 deletions
```diff
@@ -77,6 +77,22 @@ def __post_init__(self):
             )
 
 
+@dataclass
+class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
+    """
+    Config for fake quantizing weights or activations to NVIDIA's NVFP4 format
+    according to https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/.
+
+    Fake quantization numerics follow `NVFP4Tensor` closely: https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/nvfp4_tensor.py.
+
+    Args:
+        use_per_tensor_scale (bool): Whether to use two-level per-tensor fp32 scaling
+            after the initial fp8 (e4m3) block-wise scaling (default True)
+    """
+
+    use_per_tensor_scale: bool = True
+
+
 @dataclass
 class IntxFakeQuantizeConfig(FakeQuantizeConfigBase):
     """
@@ -332,6 +348,10 @@ def _infer_fake_quantize_configs(
     Return a 2-tuple of (activation_config, weight_config) for fake quantization.
     """
     # avoid circular imports
+    from torchao.prototype.mx_formats import (
+        NVFP4InferenceConfig,
+        NVFP4MMConfig,
+    )
     from torchao.quantization import (
         Float8DynamicActivationFloat8WeightConfig,
         Float8DynamicActivationInt4WeightConfig,
@@ -385,6 +405,17 @@ def _infer_fake_quantize_configs(
             group_size=128,
             is_symmetric=True,
         )
+    elif isinstance(base_config, NVFP4InferenceConfig):
+        # Note: today the PTQ config does not allow the user to specify
+        # `per_tensor_scales` due to serialization concerns. In the future
+        # we may add a way to compute these dynamically (for activations),
+        # but for now QAT will mimic the existing behavior of not having
+        # `per_tensor_scales` (subject to change)
+        if NVFP4MMConfig.DYNAMIC:
+            act_config = NVFP4FakeQuantizeConfig(False)
+        else:
+            act_config = None
+        weight_config = NVFP4FakeQuantizeConfig(False)
     else:
         raise ValueError("Unexpected base config: %s" % base_config)
     return (act_config, weight_config)
```
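
As a usage note (a sketch, not part of the diff): the new config plugs into `QATConfig` like the existing fake quantize configs, and `use_per_tensor_scale` toggles the second scaling level described in the docstring. The `model` below is assumed to be defined elsewhere.

```python
from torchao.quantization import quantize_
from torchao.quantization.qat import NVFP4FakeQuantizeConfig, QATConfig

# use_per_tensor_scale=True (the default) applies a per-tensor fp32 scale on
# top of the per-block fp8 (e4m3) scales; False uses the block scales only.
qat_config = QATConfig(
    activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale=True),
    weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale=True),
    step="prepare",
)
quantize_(model, qat_config)  # `model` is a placeholder
```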

torchao/quantization/qat/fake_quantizer.py

Lines changed: 52 additions & 2 deletions
```diff
@@ -34,6 +34,7 @@
     FakeQuantizeConfigBase,
     Float8FakeQuantizeConfig,
     IntxFakeQuantizeConfig,
+    NVFP4FakeQuantizeConfig,
 )
 from .utils import (
     _fake_quantize_per_channel_group,
@@ -59,8 +60,10 @@ def __repr__(self) -> str:
     def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase":
         if isinstance(config, IntxFakeQuantizeConfig):
             return IntxFakeQuantizer(config)
-        if isinstance(config, Float8FakeQuantizeConfig):
+        elif isinstance(config, Float8FakeQuantizeConfig):
             return Float8FakeQuantizer(config)
+        elif isinstance(config, NVFP4FakeQuantizeConfig):
+            return NVFP4FakeQuantizer(config)
         else:
             raise ValueError(f"Unknown config type: {config}")
 
@@ -73,6 +76,7 @@ class Float8FakeQuantizer(FakeQuantizerBase):
     def __init__(self, config: Float8FakeQuantizeConfig):
         super().__init__()
         self.config = config
+        torch._C._log_api_usage_once("torchao.quantization.qat.Float8FakeQuantizer")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         original_dtype = x.dtype
@@ -91,14 +95,60 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return dq
 
 
+class NVFP4FakeQuantizer(FakeQuantizerBase):
+    """
+    Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config.
+    """
+
+    def __init__(self, config: NVFP4FakeQuantizeConfig):
+        super().__init__()
+        torch._C._log_api_usage_once("torchao.quantization.qat.NVFP4FakeQuantizer")
+        self.config = config
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        from torchao.prototype.mx_formats.nvfp4_tensor import (
+            _nvfp4_quantize,
+            per_tensor_amax_to_scale,
+        )
+
+        block_size = 16
+        original_shape = x.shape
+        if x.dim() == 3:
+            x = x.view(-1, x.shape[-1])
+        if self.config.use_per_tensor_scale:
+            tensor_amax = torch.max(torch.abs(x))
+            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+        else:
+            per_tensor_scale = None
+
+        # quantize
+        scale, q = _nvfp4_quantize(
+            x,
+            block_size=block_size,
+            per_tensor_scale=per_tensor_scale,
+            skip_dtype_cast_and_packing=True,
+        )
+        if self.config.use_per_tensor_scale:
+            scale = scale * per_tensor_scale
+        assert q.dtype == x.dtype
+        assert scale.dtype == torch.float32
+
+        # dequantize
+        M, K = q.shape[0], q.shape[1]
+        q = q.view(M, K // block_size, block_size)
+        scale = scale.view(M, K // block_size, 1)
+        dq = q * scale
+        return dq.view(original_shape).to(x.dtype)
+
+
 class IntxFakeQuantizer(FakeQuantizerBase):
     """
     Generic module for applying integer fake quantization to a tensor, as specified in the config.
     """
 
     def __init__(self, config: IntxFakeQuantizeConfig):
         super().__init__()
-        torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizer")
+        torch._C._log_api_usage_once("torchao.quantization.qat.IntxFakeQuantizer")
         self.config = config
         self.enabled = True
         self.scale: Optional[torch.Tensor] = None
```
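
As a standalone illustration (a sketch, not part of the diff), the new fake quantizer can be applied directly to a tensor to inspect the simulated NVFP4 round trip; the shape and device below are arbitrary choices.

```python
import torch
from torchao.quantization.qat import NVFP4FakeQuantizeConfig, NVFP4FakeQuantizer

x = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")  # K divisible by 16
fq = NVFP4FakeQuantizer(NVFP4FakeQuantizeConfig(use_per_tensor_scale=True))
x_fq = fq(x)  # same shape and dtype as x, values limited to what NVFP4 can represent

# Manual SQNR, analogous to the compute_error() check in the tests
noise = (x.float() - x_fq.float()).pow(2).mean()
sqnr_db = 10 * torch.log10(x.float().pow(2).mean() / noise)
print(sqnr_db.item())
```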

torchao/quantization/qat/linear.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -92,7 +92,9 @@ def __init__(
 
         # initialize weight fake quantizer
         if weight_config is not None:
-            if isinstance(weight_config.granularity, PerGroup):
+            if isinstance(weight_config, IntxFakeQuantizeConfig) and isinstance(
+                weight_config.granularity, PerGroup
+            ):
                 group_size = weight_config.group_size
                 if group_size is not None and in_features % group_size != 0:
                     raise ValueError(
```
