@@ -117,6 +117,9 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
117
117
for i in range (iters ):
118
118
A1 = torch .randn (1024 , 1024 , device = device , dtype = dtype )
119
119
C , S = F .quantize_blockwise (A1 , blocksize = blocksize , nested = nested )
120
+ if i == 0 :
121
+ d = S .as_dict ()
122
+ S = F .QuantState .from_dict (d , device = torch .device (device ))
120
123
A2 = F .dequantize_blockwise (C , S )
121
124
diff = torch .abs (A1 - A2 ).float ()
122
125
reldiff = diff / torch .abs (A1 .float () + 1e-8 )
@@ -133,6 +136,9 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
133
136
for i in range (iters ):
134
137
A1 = torch .rand (1024 , 1024 , device = device , dtype = dtype )
135
138
C , S = F .quantize_blockwise (A1 , blocksize = blocksize , nested = nested , code = code )
139
+ if i == 0 :
140
+ d = S .as_dict ()
141
+ S = F .QuantState .from_dict (d , device = torch .device (device ))
136
142
A2 = F .dequantize_blockwise (C , S )
137
143
diff = torch .abs (A1 - A2 ).float ()
138
144
reldiff = diff / torch .abs (A1 .float () + 1e-8 )
@@ -242,6 +248,9 @@ def test_fp8_quant(self, device):
242
248
for i in range (100 ):
243
249
A1 = torch .randn (1024 , 1024 , device = device )
244
250
C , SC = F .quantize_blockwise (A1 , code = code )
251
+ if i == 0 :
252
+ d = SC .as_dict ()
253
+ SC = F .QuantState .from_dict (d , device = torch .device (device ))
245
254
A2 = F .dequantize_blockwise (C , SC )
246
255
diff = torch .abs (A1 - A2 )
247
256
reldiff = diff / torch .abs (A1 + 1e-8 )
@@ -1115,6 +1124,8 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
1115
1124
1116
1125
A1 = torch .randn (1024 , 1024 , device = device , dtype = dtype )
1117
1126
qa , SA = F .quantize_4bit (A1 , blocksize = blocksize , quant_type = quant_type )
1127
+ d = SA .as_dict ()
1128
+ SA = F .QuantState .from_dict (d , device = torch .device (device ))
1118
1129
A2 = F .dequantize_4bit (qa , SA , blocksize = blocksize , quant_type = quant_type )
1119
1130
1120
1131
err = (A1 - A2 ).abs ().float ()
0 commit comments