 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch.testing._internal import common_utils

-from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale
+from torchao.dtypes.floatx.float8_layout import preprocess_scale
 from torchao.float8.float8_utils import compute_error
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
+    Float8Tensor,
     float8_dynamic_activation_float8_weight,
     float8_weight_only,
     quantize_,
@@ -236,12 +237,8 @@ def test_serialization(self, mode: str):
             new_layer = getattr(new_model, layer_name)

             # Compare weights
-            if mode == "weight-only":
-                original_weight = original_layer.weight.tensor_impl.float8_data.to(
-                    torch.float32
-                )
-                new_weight = new_layer.weight.tensor_impl.float8_data.to(torch.float32)
-            else:
+            if mode == "static":
+                # note: we haven't migrated static quant to the new API
                 original_weight = original_layer.weight.original_weight_tensor.tensor_impl.float8_data.to(
                     torch.float32
                 )
@@ -250,6 +247,9 @@ def test_serialization(self, mode: str):
                         torch.float32
                     )
                 )
+            else:
+                original_weight = original_layer.weight.float8_data.to(torch.float32)
+                new_weight = new_layer.weight.float8_data.to(torch.float32)

             assert torch.allclose(original_weight, new_weight), (
                 f"Weights do not match for {layer_name}"
@@ -324,19 +324,15 @@ def test_mm_float8dq_per_row(

         quant_weight = test_linear.weight

-        self.assertTrue(hasattr(quant_weight, "original_weight_tensor"))
-        weight_impl = quant_weight.original_weight_tensor.tensor_impl
-
-        self.assertTrue(hasattr(weight_impl, "float8_data"))
-        self.assertTrue(hasattr(weight_impl, "scale"))
-        self.assertFalse(weight_impl.transposed)
+        self.assertTrue(hasattr(quant_weight, "float8_data"))
+        self.assertTrue(hasattr(quant_weight, "scale"))

         # Verify scale shape for row-wise quantization
         expected_scale_shape = (out_features, 1)
-        actual_scale_shape = weight_impl.scale.shape
+        actual_scale_shape = quant_weight.scale.shape
         self.assertEqual(actual_scale_shape, expected_scale_shape)

-        self.assertEqual(weight_impl.float8_data.shape, (out_features, in_features))
+        self.assertEqual(quant_weight.float8_data.shape, (out_features, in_features))

         input_tensor = torch.randn(*input_shape, device=device, dtype=dtype)

@@ -357,7 +353,7 @@ def test_mm_float8dq_per_row(
     @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
     @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
     @common_utils.parametrize("block_size", [None, (1, 32), (2, 16), (4, 8)])
-    def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
+    def test__dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
         """Test _dequantize_affine_float8 with various configurations"""

         device = "cuda"
@@ -387,7 +383,7 @@ def test_dequantize_affine_float8(self, float8_dtype, output_dtype, block_size):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    def test_dequantize_affine_float8_scale_broadcasting(self):
+    def test__dequantize_affine_float8_scale_broadcasting(self):
         """Test that scale broadcasting works correctly for block-wise quantization"""
         device = "cuda"
         # Create input tensor with known block structure
@@ -431,24 +427,24 @@ def test_float8_tensor_slicing_basic(self, granularity):
             model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         )

-        weight_impl = model.weight.original_weight_tensor.tensor_impl
+        weight = model.weight

         # Test dimension 0 slicing (rows)
-        sliced_0 = weight_impl[10:20]
+        sliced_0 = weight[10:20]
         self.assertEqual(sliced_0.shape, (10, 64))

         # Test dimension 1 slicing (columns)
-        sliced_1 = weight_impl[:, 20:40]
+        sliced_1 = weight[:, 20:40]
         self.assertEqual(sliced_1.shape, (32, 20))

         # Test combined slicing
-        sliced_both = weight_impl[5:15, 10:30]
+        sliced_both = weight[5:15, 10:30]
         self.assertEqual(sliced_both.shape, (10, 20))

         # Verify the sliced tensors are still Float8 tensors
-        self.assertTrue(isinstance(sliced_0, Float8AQTTensorImpl))
-        self.assertTrue(isinstance(sliced_1, Float8AQTTensorImpl))
-        self.assertTrue(isinstance(sliced_both, Float8AQTTensorImpl))
+        self.assertTrue(isinstance(sliced_0, Float8Tensor))
+        self.assertTrue(isinstance(sliced_1, Float8Tensor))
+        self.assertTrue(isinstance(sliced_both, Float8Tensor))

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
@@ -466,16 +462,15 @@ def test_float8_tensor_slicing_per_tensor(self):
         )

         original_weight = model.weight
-        original_impl = original_weight.original_weight_tensor.tensor_impl
-        original_scale = original_impl.scale
+        original_scale = original_weight.scale

         # Test slicing
         sliced_weight = original_weight[10:20, 20:40]
-        sliced_impl = sliced_weight.original_weight_tensor.tensor_impl
+        sliced_scale = sliced_weight.scale

         # For per-tensor quantization, scale should be identical
-        self.assertTrue(torch.equal(original_scale, sliced_impl.scale))
-        self.assertEqual(sliced_impl.scale.numel(), 1)
+        self.assertTrue(torch.equal(original_scale, sliced_scale))
+        self.assertEqual(sliced_scale.numel(), 1)

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
@@ -497,27 +492,26 @@ def test_float8_tensor_slicing_per_row(self):
         )

         original_weight = model.weight  # Shape: (32, 64)
-        original_impl = original_weight.original_weight_tensor.tensor_impl
-        original_scale = original_impl.scale  # Shape: (32, 1)
+        original_scale = model.weight.scale  # Shape: (32, 1)

         # Test row slicing (dimension 0)
         sliced_rows = original_weight[10:20]  # Shape: (10, 64)
-        sliced_impl = sliced_rows.original_weight_tensor.tensor_impl
+        sliced_scale = sliced_rows.scale

         # Scale should be sliced to match the rows
         expected_scale_shape = (10, 1)
-        self.assertEqual(sliced_impl.scale.shape, expected_scale_shape)
+        self.assertEqual(sliced_scale.shape, expected_scale_shape)

         # Verify the scale values are correct (should be subset of original)
-        self.assertTrue(torch.equal(sliced_impl.scale, original_scale[10:20]))
+        self.assertTrue(torch.equal(sliced_scale, original_scale[10:20]))

         # Test column slicing (dimension 1) - scale should not change for per-row
         sliced_cols = original_weight[:, 20:40]  # Shape: (32, 20)
-        sliced_cols_impl = sliced_cols.original_weight_tensor.tensor_impl
+        sliced_cols_scale = sliced_cols.scale

         # Scale shape should remain the same since we're not changing rows
-        self.assertEqual(sliced_cols_impl.scale.shape, (32, 1))
-        self.assertTrue(torch.equal(sliced_cols_impl.scale, original_scale))
+        self.assertEqual(sliced_cols_scale.shape, (32, 1))
+        self.assertTrue(torch.equal(sliced_cols_scale, original_scale))

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
@@ -552,11 +546,11 @@ def test_float8_tensor_slicing_edge_cases(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @unittest.skipIf(
         is_sm_version(8, 9),
         "TODO: AssertionError: tensor(-2.1562, device='cuda:0', dtype=torch.bfloat16) not greater than 15",
     )
+    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     def test_float8_tensor_slicing_functional_correctness(self, granularity):
         """Test that sliced tensors produce correct results in computations"""
         device = "cuda"
@@ -579,15 +573,16 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
         quant_weight_slice = quant_model.weight[0:16, 0:32]

         # Verify that the sliced weights maintain Float8 properties
-        self.assertTrue(hasattr(quant_weight_slice, "original_weight_tensor"))
-        sliced_impl = quant_weight_slice.original_weight_tensor.tensor_impl
-        self.assertTrue(isinstance(sliced_impl, Float8AQTTensorImpl))
+        self.assertTrue(hasattr(quant_weight_slice, "float8_data"))
+        self.assertTrue(hasattr(quant_weight_slice, "scale"))
+        sliced_impl = quant_weight_slice
+        self.assertTrue(isinstance(sliced_impl, Float8Tensor))

         # Verify sliced weight shapes
         self.assertEqual(sliced_impl.float8_data.shape, (16, 32))

         # Get original quantized weight implementation for scale comparison
-        original_quant_impl = quant_model.weight.original_weight_tensor.tensor_impl
+        original_quant_impl = quant_model.weight

         # Verify scale properties based on granularity
         if isinstance(granularity, PerTensor):
@@ -604,7 +599,7 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
             )

         # Verify that sliced quantized data matches the correct slice from original
-        original_float8_data_slice = original_quant_impl.float8_data[0:16, 0:32]
+        original_float8_data_slice = quant_model.weight.float8_data[0:16, 0:32]

         self.assertTrue(
             torch.equal(sliced_impl.float8_data, original_float8_data_slice)
@@ -675,46 +670,6 @@ def test_preprocess_scale_3d_reshape(self):
         expected_shape = (8, 1)  # Flattened (2*2*2, 1)
         self.assertEqual(result.shape, expected_shape)

-    @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
-    @common_utils.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
-    def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype):
-        quantize_affine_float8 = torch.ops.torchao.quantize_affine_float8
-        dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8
-        input = torch.randn(10, 10)
-        with torch.no_grad():
-            torch._dynamo.reset()
-            expected_scale = torch.tensor(2.0)
-            expected_quantized = quantize_affine_float8(
-                input,
-                expected_scale,
-                float8_dtype=float8_dtype,
-            )
-            expected_dequantized = dequantize_affine_float8(
-                expected_quantized,
-                expected_scale,
-                output_dtype=hp_dtype,
-            )
-            test_q, (code_q,) = torch._inductor.utils.run_and_get_code(
-                torch.compile(quantize_affine_float8),
-                input,
-                expected_scale,
-                float8_dtype=float8_dtype,
-            )
-            torch.testing.FileCheck().check(
-                "torch.ops.torchao.quantize_affine_float8.default"
-            ).run(code_q)
-            test_dq, (code_dq,) = torch._inductor.utils.run_and_get_code(
-                torch.compile(dequantize_affine_float8),
-                test_q,
-                expected_scale,
-                hp_dtype,
-            )
-            torch.testing.FileCheck().check(
-                "torch.ops.torchao.dequantize_affine_float8.default"
-            ).run(code_dq)
-            torch.testing.assert_close(expected_quantized, test_q)
-            torch.testing.assert_close(expected_dequantized, test_dq)
-

 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)

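For reference, here is a minimal sketch (not part of this diff) of the access pattern the updated tests rely on: after quantize_ with Float8DynamicActivationFloat8WeightConfig, the weight is a Float8Tensor whose float8_data and scale live directly on the tensor, instead of behind weight.original_weight_tensor.tensor_impl. The PerRow import location and the CUDA / SM 8.9 requirement are assumptions carried over from the tests above.

import torch

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8Tensor,
    PerRow,  # assumed import location; the tests above only show PerRow() being used
    quantize_,
)

# Requires a CUDA GPU with compute capability >= 8.9, as in the skipIf guards above.
model = torch.nn.Linear(64, 32, dtype=torch.bfloat16, device="cuda")
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

weight = model.weight
assert isinstance(weight, Float8Tensor)
assert weight.float8_data.shape == (32, 64)  # raw fp8 payload, no tensor_impl hop
assert weight.scale.shape == (32, 1)  # one scale per row

# Slicing preserves the subclass and slices the scale, as the slicing tests check.
sliced = weight[10:20]
assert isinstance(sliced, Float8Tensor)
assert torch.equal(sliced.scale, weight.scale[10:20])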