Skip to content

Commit 5eb1e03

Browse files
committed
Add test
Signed-off-by: cyy <[email protected]>
1 parent 7a3f542 commit 5eb1e03

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

tests/test_functional.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
117117
for i in range(iters):
118118
A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
119119
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
120+
if i == 0:
121+
d = S.as_dict()
122+
S = F.QuantState.from_dict(d, device=torch.device(device))
120123
A2 = F.dequantize_blockwise(C, S)
121124
diff = torch.abs(A1 - A2).float()
122125
reldiff = diff / torch.abs(A1.float() + 1e-8)
@@ -133,6 +136,9 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
133136
for i in range(iters):
134137
A1 = torch.rand(1024, 1024, device=device, dtype=dtype)
135138
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code)
139+
if i == 0:
140+
d = S.as_dict()
141+
S = F.QuantState.from_dict(d, device=torch.device(device))
136142
A2 = F.dequantize_blockwise(C, S)
137143
diff = torch.abs(A1 - A2).float()
138144
reldiff = diff / torch.abs(A1.float() + 1e-8)
@@ -242,6 +248,9 @@ def test_fp8_quant(self, device):
242248
for i in range(100):
243249
A1 = torch.randn(1024, 1024, device=device)
244250
C, SC = F.quantize_blockwise(A1, code=code)
251+
if i == 0:
252+
d = SC.as_dict()
253+
SC = F.QuantState.from_dict(d, device=torch.device(device))
245254
A2 = F.dequantize_blockwise(C, SC)
246255
diff = torch.abs(A1 - A2)
247256
reldiff = diff / torch.abs(A1 + 1e-8)
@@ -1115,6 +1124,8 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11151124

11161125
A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
11171126
qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
1127+
d = SA.as_dict()
1128+
SA = F.QuantState.from_dict(d, device=torch.device(device))
11181129
A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
11191130

11201131
err = (A1 - A2).abs().float()

0 commit comments

Comments
 (0)