30
30
from torchao .float8 .float8_utils import compute_error
31
31
from torchao .quantization import (
32
32
Float8DynamicActivationFloat8WeightConfig ,
33
- float8_dynamic_activation_float8_weight ,
34
- float8_weight_only ,
33
+ Float8StaticActivationFloat8WeightConfig ,
34
+ Float8WeightOnlyConfig ,
35
35
quantize_ ,
36
36
)
37
37
from torchao .quantization .granularity import (
38
38
PerRow ,
39
39
PerTensor ,
40
40
)
41
- from torchao .quantization .quant_api import (
42
- float8_static_activation_float8_weight ,
43
- )
44
41
from torchao .quantization .quant_primitives import (
45
42
MappingType ,
46
43
_choose_scale_float8 ,
@@ -117,17 +114,24 @@ def test_fp8_linear_variants(
117
114
torch .float8_e4m3fn ,
118
115
scale_dtype = torch .float32 ,
119
116
)
117
+ fp8_dq_cur_version = Float8DynamicActivationFloat8WeightConfig .VERSION
118
+ fp8wo_cur_version = Float8WeightOnlyConfig .VERSION
119
+ Float8DynamicActivationFloat8WeightConfig .VERSION = 1
120
+ Float8WeightOnlyConfig .VERSION = 1
120
121
mode_map = {
121
122
"dynamic" : partial (
122
- float8_dynamic_activation_float8_weight , granularity = granularity
123
+ Float8DynamicActivationFloat8WeightConfig ,
124
+ granularity = granularity ,
123
125
),
124
- "weight-only" : float8_weight_only ,
126
+ "weight-only" : Float8WeightOnlyConfig ,
125
127
"static" : partial (
126
- float8_static_activation_float8_weight ,
128
+ Float8StaticActivationFloat8WeightConfig ,
127
129
scale = scale ,
128
130
granularity = granularity ,
129
131
),
130
132
}
133
+ Float8DynamicActivationFloat8WeightConfig .VERSION = fp8_dq_cur_version
134
+ Float8WeightOnlyConfig .VERSION = fp8wo_cur_version
131
135
132
136
# Create a linear layer with bfloat16 dtype
133
137
model = ToyLinearModel (K , N ).eval ().to (dtype ).to ("cuda" )
@@ -152,7 +156,7 @@ def test_fp8_linear_variants(
152
156
)
153
157
def test_invalid_granularity (self ):
154
158
with pytest .raises (ValueError , match = "Invalid granularity specification" ):
155
- float8_dynamic_activation_float8_weight (granularity = "invalid" )
159
+ Float8DynamicActivationFloat8WeightConfig (granularity = "invalid" )
156
160
157
161
@unittest .skipIf (
158
162
not is_sm_at_least_89 (), "Requires GPU with compute capability >= 8.9"
@@ -162,7 +166,9 @@ def test_mismatched_granularity(self):
162
166
ValueError ,
163
167
match = "Different granularities for activation and weight are not supported" ,
164
168
):
165
- float8_dynamic_activation_float8_weight (granularity = (PerTensor (), PerRow ()))
169
+ Float8DynamicActivationFloat8WeightConfig (
170
+ granularity = (PerTensor (), PerRow ())
171
+ )
166
172
167
173
@unittest .skipIf (
168
174
not is_sm_at_least_89 (), "Requires GPU with compute capability >= 8.9"
@@ -172,8 +178,8 @@ class UnsupportedGranularity:
172
178
pass
173
179
174
180
with pytest .raises (ValueError , match = "Invalid granularity types" ):
175
- float8_dynamic_activation_float8_weight (
176
- granularity = (UnsupportedGranularity (), UnsupportedGranularity ())
181
+ Float8DynamicActivationFloat8WeightConfig (
182
+ granularity = (UnsupportedGranularity (), UnsupportedGranularity ()),
177
183
)
178
184
179
185
@unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
@@ -187,7 +193,8 @@ def test_per_row_with_float32(self):
187
193
):
188
194
model = ToyLinearModel (64 , 64 ).eval ().to (torch .float32 ).to ("cuda" )
189
195
quantize_ (
190
- model , float8_dynamic_activation_float8_weight (granularity = PerRow ())
196
+ model ,
197
+ Float8DynamicActivationFloat8WeightConfig (granularity = PerRow ()),
191
198
)
192
199
193
200
@unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
@@ -199,19 +206,26 @@ def test_serialization(self, mode: str):
199
206
# Create and quantize the model
200
207
model = ToyLinearModel (16 , 32 ).to (device = "cuda" )
201
208
209
+ fp8_dq_cur_version = Float8DynamicActivationFloat8WeightConfig .VERSION
210
+ fp8wo_cur_version = Float8WeightOnlyConfig .VERSION
211
+ Float8DynamicActivationFloat8WeightConfig .VERSION = 1
212
+ Float8WeightOnlyConfig .VERSION = 1
202
213
mode_map = {
203
214
"dynamic" : partial (
204
- float8_dynamic_activation_float8_weight , granularity = PerTensor ()
215
+ Float8DynamicActivationFloat8WeightConfig ,
216
+ granularity = PerTensor (),
205
217
),
206
- "weight-only" : float8_weight_only ,
218
+ "weight-only" : Float8WeightOnlyConfig ,
207
219
"static" : partial (
208
- float8_static_activation_float8_weight ,
220
+ Float8StaticActivationFloat8WeightConfig ,
209
221
scale = torch .tensor (1.0 , dtype = torch .float32 , device = "cuda" ),
210
222
granularity = PerTensor (),
211
223
),
212
224
}
225
+
213
226
factory = mode_map [mode ]()
214
227
quantize_ (model , factory )
228
+ print ("model:" , model )
215
229
216
230
# Save the state dict to an in-memory buffer
217
231
buffer = io .BytesIO ()
@@ -262,6 +276,10 @@ def test_serialization(self, mode: str):
262
276
original_layer .weight .scale , new_layer .weight .scale
263
277
), f"Scales do not match for { layer_name } "
264
278
279
+ # restore in the end
280
+ Float8DynamicActivationFloat8WeightConfig .VERSION = fp8_dq_cur_version
281
+ Float8WeightOnlyConfig .VERSION = fp8wo_cur_version
282
+
265
283
@unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
266
284
@unittest .skipIf (
267
285
not is_sm_at_least_89 (), "Requires GPU with compute capability >= 8.9"
@@ -274,9 +292,13 @@ def test_fp8_weight_dimension_warning(self):
274
292
with self .assertLogs (
275
293
"torchao.quantization.quant_api" , level = "INFO"
276
294
) as log_context :
295
+ fp8wo_cur_version = Float8WeightOnlyConfig .VERSION
296
+ Float8DynamicActivationFloat8WeightConfig .VERSION = 1
277
297
quantize_ (
278
- model , float8_dynamic_activation_float8_weight (granularity = PerTensor ())
298
+ model ,
299
+ Float8DynamicActivationFloat8WeightConfig (granularity = PerTensor ()),
279
300
)
301
+ Float8DynamicActivationFloat8WeightConfig .VERSION = fp8wo_cur_version
280
302
print (model )
281
303
282
304
# Verify warning messages for both layers
@@ -319,9 +341,13 @@ def test_mm_float8dq_per_row(
319
341
torch .nn .Linear (in_features , out_features , bias = bias ).to (device ).to (dtype )
320
342
)
321
343
test_linear = copy .deepcopy (ref_linear )
344
+ fp8_dq_cur_version = Float8DynamicActivationFloat8WeightConfig .VERSION
345
+ Float8DynamicActivationFloat8WeightConfig .VERSION = 1
322
346
quantize_ (
323
- test_linear , Float8DynamicActivationFloat8WeightConfig (granularity = PerRow ())
347
+ test_linear ,
348
+ Float8DynamicActivationFloat8WeightConfig (granularity = PerRow ()),
324
349
)
350
+ Float8DynamicActivationFloat8WeightConfig .VERSION = fp8_dq_cur_version
325
351
326
352
quant_weight = test_linear .weight
327
353
@@ -471,9 +497,13 @@ def test_float8_tensor_slicing_basic(self, granularity):
471
497
472
498
# Create and quantize a model
473
499
model = torch .nn .Linear (64 , 32 , bias = False ).to (device ).to (dtype )
500
+ fp8_dq_cur_version = Float8DynamicActivationFloat8WeightConfig .VERSION
501
+ Float8DynamicActivationFloat8WeightConfig .VERSION = 1
474
502
quantize_ (
475
- model , Float8DynamicActivationFloat8WeightConfig (granularity = granularity )
503
+ model ,
504
+ Float8DynamicActivationFloat8WeightConfig (granularity = granularity ),
476
505
)
506
+ Float8DynamicActivationFloat8WeightConfig .VERSION = fp8_dq_cur_version
477
507
478
508
weight_impl = model .weight .original_weight_tensor .tensor_impl
479
509
@@ -505,9 +535,13 @@ def test_float8_tensor_slicing_per_tensor(self):
505
535
506
536
# Create and quantize with per-tensor granularity
507
537
model = torch .nn .Linear (64 , 32 , bias = False ).to (device ).to (dtype )
538
+ fp8_dq_cur_version = Float8DynamicActivationFloat8WeightConfig .VERSION
539
+ Float8DynamicActivationFloat8WeightConfig .VERSION = 1
508
540
quantize_ (
509
- model , Float8DynamicActivationFloat8WeightConfig (granularity = PerTensor ())
541
+ model ,
542
+ Float8DynamicActivationFloat8WeightConfig (granularity = PerTensor ()),
510
543
)
544
+ Float8DynamicActivationFloat8WeightConfig .VERSION = fp8_dq_cur_version
511
545
512
546
original_weight = model .weight
513
547
original_impl = original_weight .original_weight_tensor .tensor_impl
@@ -536,9 +570,13 @@ def test_float8_tensor_slicing_per_row(self):
536
570
537
571
# Create and quantize with per-row granularity
538
572
model = torch .nn .Linear (64 , 32 , bias = False ).to (device ).to (dtype )
573
+ fp8_dq_cur_version = Float8DynamicActivationFloat8WeightConfig .VERSION
574
+ Float8DynamicActivationFloat8WeightConfig .VERSION = 1
539
575
quantize_ (
540
- model , Float8DynamicActivationFloat8WeightConfig (granularity = PerRow ())
576
+ model ,
577
+ Float8DynamicActivationFloat8WeightConfig (granularity = PerRow ()),
541
578
)
579
+ Float8DynamicActivationFloat8WeightConfig .VERSION = fp8_dq_cur_version
542
580
543
581
original_weight = model .weight # Shape: (32, 64)
544
582
original_impl = original_weight .original_weight_tensor .tensor_impl
@@ -574,9 +612,13 @@ def test_float8_tensor_slicing_edge_cases(self):
574
612
575
613
# Create and quantize a model
576
614
model = torch .nn .Linear (64 , 32 , bias = False ).to (device ).to (dtype )
615
+ fp8_dq_cur_version = Float8DynamicActivationFloat8WeightConfig .VERSION
616
+ Float8DynamicActivationFloat8WeightConfig .VERSION = 1
577
617
quantize_ (
578
- model , Float8DynamicActivationFloat8WeightConfig (granularity = PerTensor ())
618
+ model ,
619
+ Float8DynamicActivationFloat8WeightConfig (granularity = PerTensor ()),
579
620
)
621
+ Float8DynamicActivationFloat8WeightConfig .VERSION = fp8_dq_cur_version
580
622
581
623
original_weight = model .weight
582
624
@@ -611,10 +653,13 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
611
653
torch .nn .Linear (64 , 48 , bias = False ).to (device ).to (dtype )
612
654
) # 48 is divisible by 16
613
655
quant_model = copy .deepcopy (ref_model )
656
+ fp8_dq_cur_version = Float8DynamicActivationFloat8WeightConfig .VERSION
657
+ Float8DynamicActivationFloat8WeightConfig .VERSION = 1
614
658
quantize_ (
615
659
quant_model ,
616
660
Float8DynamicActivationFloat8WeightConfig (granularity = granularity ),
617
661
)
662
+ Float8DynamicActivationFloat8WeightConfig .VERSION = fp8_dq_cur_version
618
663
619
664
# Create input with batch size that works well with slicing
620
665
input_tensor = torch .randn (8 , 64 , device = device , dtype = dtype )
@@ -720,6 +765,7 @@ def test_preprocess_scale_3d_reshape(self):
720
765
self .assertEqual (result .shape , expected_shape )
721
766
722
767
@torch .no_grad ()
768
+ @unittest .skip ("test is flaky in CI, will turn on a bit later" )
723
769
@unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
724
770
@unittest .skipIf (
725
771
not is_sm_at_least_90 (), "Requires GPU with compute capability >= 9.0"
@@ -743,7 +789,14 @@ def test_expected_kernels_on_gpu(self, granularity, torch_compile_mode):
743
789
m = torch .nn .Sequential (
744
790
torch .nn .Linear (K , N , device = "cuda" , dtype = torch .bfloat16 )
745
791
)
746
- quantize_ (m , Float8DynamicActivationFloat8WeightConfig (granularity = granularity ))
792
+ fp8_dq_cur_version = Float8DynamicActivationFloat8WeightConfig .VERSION
793
+ Float8DynamicActivationFloat8WeightConfig .VERSION = 1
794
+ quantize_ (
795
+ m ,
796
+ Float8DynamicActivationFloat8WeightConfig (granularity = granularity ),
797
+ )
798
+ Float8DynamicActivationFloat8WeightConfig .VERSION = fp8_dq_cur_version
799
+
747
800
m = torch .compile (m , mode = torch_compile_mode )
748
801
x = torch .randn (M , K , device = "cuda" , dtype = torch .bfloat16 )
749
802
0 commit comments