Skip to content

Commit 2fd06de

Browse files
authored
Fix NVFP4 to_copy (#2812)
* Fix NVFP4 to_copy

  **Summary:** Fixes #2811

  **Test Plan:**

  ```
  pytest test/prototype/mx_formats/test_nvfp4_tensor.py -k to_copy
  ```

* Update test_nvfp4_tensor.py
1 parent 9978bca commit 2fd06de

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,3 +523,25 @@ def test_nvfp4_matmul_with_amax(
523523
assert sqnr >= SQNR_THRESHOLD, (
524524
f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD}, use_gelu={use_gelu}, mm_config={mm_config}, compile={compile}, bias={bias}"
525525
)
526+
527+
528+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
529+
@pytest.mark.skipif(
530+
not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
531+
)
532+
def test_nvfp4_to_copy():
533+
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
534+
535+
x = NVFP4Tensor.to_nvfp4(torch.randn((32, 128))).cuda()
536+
y = torch.ops.aten._to_copy(x, dtype=torch.bfloat16)
537+
assert torch.equal(x.qdata, y.qdata)
538+
assert torch.equal(x._scale_e4m3, y._scale_e4m3)
539+
assert x._per_tensor_scale is None
540+
assert y._per_tensor_scale is None
541+
assert x._act_per_tensor_scale is None
542+
assert y._act_per_tensor_scale is None
543+
assert x._block_size == y._block_size
544+
assert x.use_triton_kernel == y.use_triton_kernel
545+
assert x.act_quant_kwargs == y.act_quant_kwargs
546+
assert x.dtype == torch.float32
547+
assert y.dtype == torch.bfloat16

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,10 +310,10 @@ def nvfp4_to_copy(func, types, args, kwargs):
310310

311311
if dtype is not None:
312312
res = NVFP4Tensor(
313+
tensor.qdata,
313314
tensor._scale_e4m3,
314315
tensor._per_tensor_scale,
315316
tensor._act_per_tensor_scale,
316-
tensor._data,
317317
tensor._block_size,
318318
dtype,
319319
tensor._is_swizzled_scales,

0 commit comments

Comments
 (0)