Merged
Changes from all commits (52 commits)
c358b1b
[bc-breaking] Generalize FakeQuantizeConfig beyond intx
andrewor14 Jul 29, 2025
d076264
New multi-step QAT API
andrewor14 Jul 29, 2025
8f56651
Update on "New multi-step QAT API"
andrewor14 Jul 29, 2025
7a9fe90
Update on "New multi-step QAT API"
andrewor14 Jul 30, 2025
1e88ebf
Deprecate old QAT APIs
andrewor14 Jul 30, 2025
12e8c3f
Update base for Update on "Deprecate old QAT APIs"
andrewor14 Jul 30, 2025
2ed5e50
Update on "Deprecate old QAT APIs"
andrewor14 Jul 30, 2025
019b665
Update base for Update on "Deprecate old QAT APIs"
andrewor14 Jul 30, 2025
b0c4721
Update on "Deprecate old QAT APIs"
andrewor14 Jul 30, 2025
0eb0983
Update base for Update on "Deprecate old QAT APIs"
andrewor14 Jul 30, 2025
c91b218
Update on "Deprecate old QAT APIs"
andrewor14 Jul 30, 2025
affc74e
Update base for Update on "Deprecate old QAT APIs"
andrewor14 Jul 30, 2025
3069075
Update on "Deprecate old QAT APIs"
andrewor14 Jul 30, 2025
b41f4e7
Update base for Update on "Deprecate old QAT APIs"
andrewor14 Jul 31, 2025
12d920b
Update on "Deprecate old QAT APIs"
andrewor14 Jul 31, 2025
68728d7
Update base for Update on "Deprecate old QAT APIs"
andrewor14 Jul 31, 2025
52f72a5
Update on "Deprecate old QAT APIs"
andrewor14 Jul 31, 2025
56415d6
Update base for Update on "Deprecate old QAT APIs"
andrewor14 Jul 31, 2025
08f87af
Update on "Deprecate old QAT APIs"
andrewor14 Jul 31, 2025
be45ff4
Update base for Update on "Deprecate old QAT APIs"
andrewor14 Jul 31, 2025
28b3b41
Update on "Deprecate old QAT APIs"
andrewor14 Jul 31, 2025
9baae23
Update base for Update on "Deprecate old QAT APIs"
andrewor14 Aug 1, 2025
62cd942
Update on "Deprecate old QAT APIs"
andrewor14 Aug 1, 2025
1c30bbb
Update base for Update on "Deprecate old QAT APIs"
andrewor14 Aug 1, 2025
2fbfbb6
Update on "Deprecate old QAT APIs"
andrewor14 Aug 1, 2025
87175e9
Add NVFP4 QAT
andrewor14 Aug 1, 2025
8f4953f
Update base for Update on "Add NVFP4 QAT"
andrewor14 Aug 18, 2025
a85be2e
Update on "Add NVFP4 QAT"
andrewor14 Aug 18, 2025
7473c23
Update base for Update on "Add NVFP4 QAT"
andrewor14 Aug 18, 2025
807a60b
Update on "Add NVFP4 QAT"
andrewor14 Aug 18, 2025
7ad639b
Update base for Update on "Add NVFP4 QAT"
andrewor14 Aug 18, 2025
32b6fa7
Update on "Add NVFP4 QAT"
andrewor14 Aug 18, 2025
92a26ac
Update base for Update on "Add NVFP4 QAT"
andrewor14 Aug 18, 2025
a1a4e51
Update on "Add NVFP4 QAT"
andrewor14 Aug 18, 2025
f9c7ccd
Update base for Update on "Add NVFP4 QAT"
andrewor14 Aug 18, 2025
15d43e0
Update on "Add NVFP4 QAT"
andrewor14 Aug 18, 2025
6fc9be3
Update base for Update on "Add NVFP4 QAT"
andrewor14 Aug 18, 2025
908abe2
Update on "Add NVFP4 QAT"
andrewor14 Aug 18, 2025
da9946c
Update base for Update on "Add NVFP4 QAT"
andrewor14 Aug 18, 2025
483be23
Update on "Add NVFP4 QAT"
andrewor14 Aug 18, 2025
ff675ee
Update base for Update on "Add NVFP4 QAT"
andrewor14 Aug 18, 2025
8b4092b
Update on "Add NVFP4 QAT"
andrewor14 Aug 18, 2025
5b41e42
Update base for Update on "Add NVFP4 QAT"
andrewor14 Aug 22, 2025
cec5159
Update on "Add NVFP4 QAT"
andrewor14 Aug 22, 2025
c881741
Update base for Update on "Add NVFP4 QAT"
andrewor14 Aug 22, 2025
1ff5de1
Update on "Add NVFP4 QAT"
andrewor14 Aug 22, 2025
362706d
Update base for Update on "Add NVFP4 QAT"
andrewor14 Aug 25, 2025
732fb16
Update on "Add NVFP4 QAT"
andrewor14 Aug 25, 2025
cda3a85
Update base for Update on "Add NVFP4 QAT"
andrewor14 Aug 25, 2025
80cc501
Update on "Add NVFP4 QAT"
andrewor14 Aug 25, 2025
bc2a059
Update base for Update on "Add NVFP4 QAT"
andrewor14 Aug 25, 2025
a024e29
Update on "Add NVFP4 QAT"
andrewor14 Aug 25, 2025
47 changes: 44 additions & 3 deletions test/quantization/test_qat.py
@@ -118,8 +118,8 @@ def __init__(self):
self.sub = Sub()
self.linear2 = torch.nn.Linear(256, 512, bias=False).to(torch.float)

def example_inputs(self):
return (torch.randn(1, 512).to(torch.float),)
def example_inputs(self, device: torch.device = None):
return (torch.randn((1, 512), device=device).to(torch.float),)

def _get_all_weight_scales(self) -> List[torch.Tensor]:
return [
@@ -1928,7 +1928,7 @@ def test_quantize_api_fp8_int4(self):
"""
self._test_quantize_api_against_ptq(
Float8DynamicActivationInt4WeightConfig(),
target_prepare_sqnr=15,
target_prepare_sqnr=12,
target_convert_sqnr=float("inf"),
)

@@ -1952,6 +1952,47 @@ def test_infer_fp8_int4_config(self):
self.assertEqual(weight_config.group_size, 128)
self.assertTrue(weight_config.is_symmetric)

@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
def test_quantize_api_nvfp4(self):
"""
Test the following:
quantize_(model, QATConfig(NVFP4InferenceConfig(), step="prepare"))
quantize_(model, QATConfig(NVFP4InferenceConfig(), step="convert"))
"""
from torchao.prototype.mx_formats import NVFP4InferenceConfig

self._test_quantize_api_against_ptq(
NVFP4InferenceConfig(),
target_prepare_sqnr=8,
target_convert_sqnr=float("inf"),
)

@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
@parametrize("use_per_tensor_scale", [True, False])
def test_qat_nvfp4(self, use_per_tensor_scale: bool):
"""
Test QAT with `NVFP4FakeQuantizeConfig`.
"""
from torchao.prototype.qat import NVFP4FakeQuantizeConfig

torch.manual_seed(self.SEED)
m = M().cuda()
baseline_model = copy.deepcopy(m)
qat_config = QATConfig(
activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
step="prepare",
)
quantize_(m, qat_config)

# Compare prepared values
torch.manual_seed(self.SEED)
x = m.example_inputs("cuda")
out = m(*x)
baseline_out = baseline_model(*x)
sqnr = compute_error(out, baseline_out).item()
self.assertGreater(sqnr, 24)


instantiate_parametrized_tests(TestQAT)

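The two new tests above exercise the multi-step QAT API end to end. For readers following along, here is a minimal sketch of that flow outside the test harness; the toy model, device, and training-loop placeholder are assumptions, while the `quantize_`/`QATConfig` calls mirror the test docstrings and this PR's module layout:

```python
import torch
from torchao.prototype.mx_formats import NVFP4InferenceConfig
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig

# Toy model; the NVFP4 tests assume a CUDA device (sm89+).
model = torch.nn.Sequential(torch.nn.Linear(512, 256)).cuda()

# Step 1 ("prepare"): swap in fake-quantized linears for training.
quantize_(model, QATConfig(NVFP4InferenceConfig(), step="prepare"))

# ... fine-tune `model` as usual here ...

# Step 2 ("convert"): replace fake quantization with real NVFP4 quantization.
quantize_(model, QATConfig(NVFP4InferenceConfig(), step="convert"))
```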
47 changes: 36 additions & 11 deletions torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -751,13 +751,37 @@ def nvfp4_quantize(
AssertionError: If input dtype is not supported, tensor size is not
divisible by block_size, tensor is not contiguous, or block_size != 16
"""
return _nvfp4_quantize(data_hp, block_size, per_tensor_scale)


class _Float8Round(torch.autograd.Function):
"""
Cast a tensor to float8 and back to float32 with backward STE.
"""

@staticmethod
def forward(ctx, x: torch.Tensor) -> torch.Tensor:
return x.to(torch.float8_e4m3fn).to(torch.float32)

@staticmethod
def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
return gy


def _nvfp4_quantize(
data_hp: torch.Tensor,
block_size: int = 16,
per_tensor_scale: Optional[torch.Tensor] = None,
skip_dtype_cast_and_packing: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
assert data_hp.dtype in (torch.bfloat16, torch.float), (
f"{data_hp.dtype} not supported"
)
assert data_hp.size(-1) % block_size == 0, "K dim must be divisible by block_size"
assert data_hp.is_contiguous(), "Only support contiguous data for now"
assert block_size == 16, "NVFP4 requires block_size=16"

orig_dtype = data_hp.dtype
orig_shape = data_hp.shape
# Convert to float32 early for consistent precision with Triton implementation
data_hp = data_hp.float().reshape(orig_shape[0], -1, block_size)
@@ -769,10 +793,8 @@ def nvfp4_quantize(
out_scales = None
if per_tensor_scale is None:
# We are doing single level scaling
block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX).to(
torch.float8_e4m3fn
)
block_scale_fp32 = block_scale_fp8.to(torch.float32)
block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX)
block_scale_fp32 = _Float8Round.apply(block_scale_fp8)
data_scaled = data_hp / block_scale_fp32.unsqueeze(-1)
out_scales = block_scale_fp8
else:
@@ -784,8 +806,8 @@ def nvfp4_quantize(
scaled_block_scales = block_scale_fp32 / per_tensor_scale
scaled_block_scales_fp8 = torch.clamp(
scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX
).to(torch.float8_e4m3fn)
scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32)
)
scaled_block_scales_fp32 = _Float8Round.apply(scaled_block_scales_fp8)
# We "temporarily" dequant the scaled_block_scales_fp32 to get the per_tensor_scale
# To apply to data
total_scale = per_tensor_scale * scaled_block_scales_fp32
@@ -794,8 +816,11 @@ def nvfp4_quantize(

data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
data_scaled = data_scaled.view(orig_shape)
data_lp = f32_to_f4_unpacked(data_scaled)
# TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
# data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
data_lp = pack_uint4(data_lp)
return out_scales, data_lp
if skip_dtype_cast_and_packing:
return out_scales.to(torch.float32), data_scaled.to(orig_dtype)
else:
data_lp = f32_to_f4_unpacked(data_scaled)
# TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
# data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
data_lp = pack_uint4(data_lp)
return out_scales.to(torch.float8_e4m3fn), data_lp
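The `_Float8Round` autograd function added above is what lets gradients flow through the fp8 rounding of the block scales during QAT. A self-contained sketch of the same straight-through-estimator idea (the `Sketch` name and toy values are hypothetical; assumes a PyTorch build with `torch.float8_e4m3fn`):

```python
import torch

class _Float8RoundSketch(torch.autograd.Function):
    # Forward: round through fp8 (e4m3), then return float32 values.
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return x.to(torch.float8_e4m3fn).to(torch.float32)

    # Backward: straight-through estimator -- treat the rounding as the identity.
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        return grad_output

scales = torch.tensor([0.0123, 1.5, 200.0], requires_grad=True)
rounded = _Float8RoundSketch.apply(scales)  # values snapped to the fp8 grid
rounded.sum().backward()
print(scales.grad)                          # tensor([1., 1., 1.])
```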
12 changes: 12 additions & 0 deletions torchao/prototype/qat/__init__.py
@@ -0,0 +1,12 @@
# Temporary location for prototype QAT features that will
# eventually live in torchao/quantization/qat

from .nvfp4 import (
    NVFP4FakeQuantizeConfig,
    NVFP4FakeQuantizer,
)

__all__ = [
    "NVFP4FakeQuantizeConfig",
    "NVFP4FakeQuantizer",
]
69 changes: 69 additions & 0 deletions torchao/prototype/qat/nvfp4.py
@@ -0,0 +1,69 @@
from dataclasses import dataclass

import torch

from torchao.prototype.mx_formats.nvfp4_tensor import (
    _nvfp4_quantize,
    per_tensor_amax_to_scale,
)
from torchao.quantization.qat import (
    FakeQuantizeConfigBase,
    FakeQuantizerBase,
)


@dataclass
class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
    """
    Config for fake quantizing weights or activations to NVIDIA's NVFP4 format
    according to https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/.

    Fake quantization numerics follow `NVFP4Tensor` closely: https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/nvfp4_tensor.py.

    Args:
        use_per_tensor_scale (bool): Whether to use two-level per-tensor fp32 scaling
            after the initial fp8 (e4m3) block-wise scaling (default True)
    """

    use_per_tensor_scale: bool = True


class NVFP4FakeQuantizer(FakeQuantizerBase):
    """
    (Prototype) Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config.
    """

    def __init__(self, config: NVFP4FakeQuantizeConfig):
        super().__init__()
        torch._C._log_api_usage_once("torchao.quantization.qat.NVFP4FakeQuantizer")
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        block_size = 16
        original_shape = x.shape
        if x.dim() == 3:
            x = x.view(-1, x.shape[-1])
        if self.config.use_per_tensor_scale:
            tensor_amax = torch.max(torch.abs(x))
            per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
        else:
            per_tensor_scale = None

        # quantize
        scale, q = _nvfp4_quantize(
            x,
            block_size=block_size,
            per_tensor_scale=per_tensor_scale,
            skip_dtype_cast_and_packing=True,
        )
        if self.config.use_per_tensor_scale:
            scale = scale * per_tensor_scale
        assert q.dtype == x.dtype
        assert scale.dtype == torch.float32

        # dequantize
        M, K = q.shape[0], q.shape[1]
        q = q.view(M, K // block_size, block_size)
        scale = scale.view(M, K // block_size, 1)
        dq = q * scale
        return dq.view(original_shape).to(x.dtype)
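As a quick sanity check of the new module, the fake quantizer can also be applied to a tensor directly. The sketch below assumes a CUDA device and a weight whose last dimension is divisible by the NVFP4 block size of 16; shapes and device are illustrative:

```python
import torch
from torchao.prototype.qat import NVFP4FakeQuantizeConfig, NVFP4FakeQuantizer

fq = NVFP4FakeQuantizer(NVFP4FakeQuantizeConfig(use_per_tensor_scale=True))

weight = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
weight_fq = fq(weight)  # same shape and dtype, values rounded to the NVFP4 grid

assert weight_fq.shape == weight.shape
assert weight_fq.dtype == weight.dtype
```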
20 changes: 19 additions & 1 deletion torchao/quantization/qat/fake_quantize_config.py
@@ -320,7 +320,6 @@ def __post_init__(self):
_log_deprecation_warning(self)


# TODO: rewrite using registration API?
def _infer_fake_quantize_configs(
base_config: AOBaseConfig,
) -> Tuple[Optional[FakeQuantizeConfigBase], Optional[FakeQuantizeConfigBase]]:
@@ -331,7 +330,15 @@ def _infer_fake_quantize_configs(

Return a 2-tuple of (activation_config, weight_config) for fake quantization.
"""
# TODO: rewrite using registration API so we don't need to import here
# avoid circular imports
from torchao.prototype.mx_formats import (
NVFP4InferenceConfig,
NVFP4MMConfig,
)
from torchao.prototype.qat import (
NVFP4FakeQuantizeConfig,
)
from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig,
Float8DynamicActivationInt4WeightConfig,
@@ -385,6 +392,17 @@ def _infer_fake_quantize_configs(
group_size=128,
is_symmetric=True,
)
elif isinstance(base_config, NVFP4InferenceConfig):
# Note: today the PTQ config does not allow the user to specify
# `per_tensor_scales` due to serialization concerns. In the future
# we may add a way to compute these dynamically (for activations),
# but for now QAT will mimic the existing behavior of not having
# `per_tensor_scales` (subject to change)
if NVFP4MMConfig.DYNAMIC:
act_config = NVFP4FakeQuantizeConfig(False)
else:
act_config = None
weight_config = NVFP4FakeQuantizeConfig(False)
else:
raise ValueError("Unexpected base config: %s" % base_config)
return (act_config, weight_config)
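The new branch above is what `QATConfig(NVFP4InferenceConfig(), step="prepare")` relies on to derive fake-quantization settings from the post-training config. A small sketch of calling the helper directly (it is a private function and subject to change; the import path follows this file's location):

```python
from torchao.prototype.mx_formats import NVFP4InferenceConfig
from torchao.quantization.qat.fake_quantize_config import _infer_fake_quantize_configs

# Derive the activation/weight fake-quantize configs QAT will use for NVFP4.
act_config, weight_config = _infer_fake_quantize_configs(NVFP4InferenceConfig())
print(weight_config)  # NVFP4FakeQuantizeConfig(use_per_tensor_scale=False)
```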
13 changes: 11 additions & 2 deletions torchao/quantization/qat/fake_quantizer.py
@@ -57,10 +57,18 @@ def __repr__(self) -> str:

@staticmethod
def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase":
# TODO: rewrite using registration API so we don't need to import here
from torchao.prototype.qat import (
NVFP4FakeQuantizeConfig,
NVFP4FakeQuantizer,
)

if isinstance(config, IntxFakeQuantizeConfig):
return IntxFakeQuantizer(config)
if isinstance(config, Float8FakeQuantizeConfig):
elif isinstance(config, Float8FakeQuantizeConfig):
return Float8FakeQuantizer(config)
elif isinstance(config, NVFP4FakeQuantizeConfig):
return NVFP4FakeQuantizer(config)
else:
raise ValueError(f"Unknown config type: {config}")

@@ -73,6 +81,7 @@ class Float8FakeQuantizer(FakeQuantizerBase):
def __init__(self, config: Float8FakeQuantizeConfig):
super().__init__()
self.config = config
torch._C._log_api_usage_once("torchao.quantization.qat.Float8FakeQuantizer")

def forward(self, x: torch.Tensor) -> torch.Tensor:
original_dtype = x.dtype
@@ -98,7 +107,7 @@ class IntxFakeQuantizer(FakeQuantizerBase):

def __init__(self, config: IntxFakeQuantizeConfig):
super().__init__()
torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizer")
torch._C._log_api_usage_once("torchao.quantization.qat.IntxFakeQuantizer")
self.config = config
self.enabled = True
self.scale: Optional[torch.Tensor] = None
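With the dispatch above, the base `from_config` factory now recognizes the prototype NVFP4 config. A minimal illustration, with import paths following the files in this PR:

```python
from torchao.prototype.qat import NVFP4FakeQuantizeConfig
from torchao.quantization.qat import FakeQuantizerBase

fq = FakeQuantizerBase.from_config(NVFP4FakeQuantizeConfig(use_per_tensor_scale=False))
print(type(fq).__name__)  # NVFP4FakeQuantizer
```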
4 changes: 3 additions & 1 deletion torchao/quantization/qat/linear.py
@@ -92,7 +92,9 @@ def __init__(

# initialize weight fake quantizer
if weight_config is not None:
if isinstance(weight_config.granularity, PerGroup):
if isinstance(weight_config, IntxFakeQuantizeConfig) and isinstance(
weight_config.granularity, PerGroup
):
group_size = weight_config.group_size
if group_size is not None and in_features % group_size != 0:
raise ValueError(