
Commit b844dbd

Bump version for float8 dynamic quant and weight only quant configs
Summary:

This PR changes the default VERSION for Float8DynamicActivationFloat8WeightConfig and Float8WeightOnlyConfig from 1 to 2 and deprecates the VERSION 1 configs and VERSION 1 quantized models; more details in #2649. It also extends the current config serialization to work with multiple config versions.

Deprecation Note:

```
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "torchao-testing/opt-125m-float8dq-row-v1-0.13-dev"
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="bfloat16",
    device_map="cuda",
)

/data/users/jerryzh/ao/torchao/core/config.py:249: UserWarning: Stored version is not the same as current default version of the config: stored_version=1, current_version=2, please check the deprecation warning
  warnings.warn(
/data/users/jerryzh/ao/torchao/dtypes/floatx/float8_layout.py:113: UserWarning: Models quantized with VERSION 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see #2649 for more details
  warnings.warn(
```

Suggestion: upgrade torchao to 0.13 or later and generate the checkpoint again:

```
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
```

Or download the checkpoint again (please let us know if the checkpoint is not updated).

Test Plan:

Tested by serializing a model with a VERSION 1 config, loading it, and checking that the warnings are properly printed:

```
python test/integration/test_loading_deprecated_checkpoint.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2650, branch: jerryzh168/stack/14
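As a worked version of the suggestion above, regenerating a VERSION 2 checkpoint might look like the following sketch. The base model name and output path are placeholders, and `safe_serialization=False` is an assumption about what tensor-subclass checkpoints currently need, not something this PR states:

```
import torch
from transformers import AutoModelForCausalLM
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_
from torchao.quantization.granularity import PerRow

# Load the original high-precision model (placeholder model name).
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", torch_dtype=torch.bfloat16, device_map="cuda"
)

# Quantize in place with the current default (VERSION 2) config.
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

# Save a fresh checkpoint (placeholder path); torchao tensor subclasses
# generally require non-safetensors serialization.
model.save_pretrained("opt-125m-float8dq-row-v2", safe_serialization=False)
```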
1 parent 3b4bc98 commit b844dbd

File tree

8 files changed: +190 −78 lines changed

test/core/test_config.py

Lines changed: 8 additions & 6 deletions

```
@@ -7,6 +7,7 @@
 import json
 import os
 import tempfile
+import warnings
 from dataclasses import dataclass
 from unittest import mock

@@ -15,7 +16,6 @@

 from torchao.core.config import (
     AOBaseConfig,
-    VersionMismatchError,
     config_from_dict,
     config_to_dict,
 )
@@ -176,7 +176,7 @@ def test_disallowed_modules():


 def test_version_mismatch():
-    """Test that version mismatch raises an error during reconstruction."""
+    """Test that version mismatch prints a warning during reconstruction."""
     # Create a config
     dummy_config = DummyNonAllowedConfig()
     reconstructable = config_to_dict(dummy_config)
@@ -186,11 +186,13 @@ def test_version_mismatch():

     # Patch to allow the module but should still fail due to version mismatch
     with mock.patch("torchao.core.config.ALLOWED_AO_MODULES", {__name__}):
-        with pytest.raises(
-            VersionMismatchError,
-            match="Version mismatch for DummyNonAllowedConfig: stored version 1 != current version 2",
-        ):
+        with warnings.catch_warnings(record=True) as caught_warnings:
             config_from_dict(reconstructable)
+        assert any(
+            "Stored version is not the same as current default version of the config"
+            in str(w.message)
+            for w in caught_warnings
+        ), "Didn't get expected warning message for version mismatch"


 def test_default_version():
```
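For intuition, the warning this test asserts can come from a version field stored alongside the config payload. Below is a minimal sketch of that idea under an assumed dict layout; torchao's actual `config_to_dict`/`config_from_dict` internals are not shown in this diff, only the class-level `VERSION` attribute and the warning text are:

```
import warnings


def config_to_dict_sketch(config) -> dict:
    # Persist the class-level VERSION so loaders can detect stale configs.
    return {
        "_type": type(config).__name__,
        "_version": type(config).VERSION,
        "_data": dict(config.__dict__),
    }


def config_from_dict_sketch(data: dict, config_cls):
    stored_version = data["_version"]
    if stored_version != config_cls.VERSION:
        # Warn instead of raising, matching the behavior change in this PR.
        warnings.warn(
            "Stored version is not the same as current default version of the "
            f"config: stored_version={stored_version}, "
            f"current_version={config_cls.VERSION}, "
            "please check the deprecation warning"
        )
    return config_cls(**data["_data"])
```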

test/dtypes/test_affine_quantized_float.py

Lines changed: 76 additions & 23 deletions

```
@@ -30,17 +30,14 @@
 from torchao.float8.float8_utils import compute_error
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
-    float8_dynamic_activation_float8_weight,
-    float8_weight_only,
+    Float8StaticActivationFloat8WeightConfig,
+    Float8WeightOnlyConfig,
     quantize_,
 )
 from torchao.quantization.granularity import (
     PerRow,
     PerTensor,
 )
-from torchao.quantization.quant_api import (
-    float8_static_activation_float8_weight,
-)
 from torchao.quantization.quant_primitives import (
     MappingType,
     _choose_scale_float8,
@@ -117,17 +114,24 @@ def test_fp8_linear_variants(
             torch.float8_e4m3fn,
             scale_dtype=torch.float32,
         )
+        fp8_dq_cur_version = Float8DynamicActivationFloat8WeightConfig.VERSION
+        fp8wo_cur_version = Float8WeightOnlyConfig.VERSION
+        Float8DynamicActivationFloat8WeightConfig.VERSION = 1
+        Float8WeightOnlyConfig.VERSION = 1
        mode_map = {
            "dynamic": partial(
-                float8_dynamic_activation_float8_weight, granularity=granularity
+                Float8DynamicActivationFloat8WeightConfig,
+                granularity=granularity,
            ),
-            "weight-only": float8_weight_only,
+            "weight-only": Float8WeightOnlyConfig,
            "static": partial(
-                float8_static_activation_float8_weight,
+                Float8StaticActivationFloat8WeightConfig,
                scale=scale,
                granularity=granularity,
            ),
        }
+        Float8DynamicActivationFloat8WeightConfig.VERSION = fp8_dq_cur_version
+        Float8WeightOnlyConfig.VERSION = fp8wo_cur_version

        # Create a linear layer with bfloat16 dtype
        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
@@ -152,7 +156,7 @@
     )
     def test_invalid_granularity(self):
         with pytest.raises(ValueError, match="Invalid granularity specification"):
-            float8_dynamic_activation_float8_weight(granularity="invalid")
+            Float8DynamicActivationFloat8WeightConfig(granularity="invalid")

     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
@@ -162,7 +166,9 @@ def test_mismatched_granularity(self):
             ValueError,
             match="Different granularities for activation and weight are not supported",
         ):
-            float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=(PerTensor(), PerRow())
+            )

     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
@@ -172,8 +178,8 @@ class UnsupportedGranularity:
             pass

         with pytest.raises(ValueError, match="Invalid granularity types"):
-            float8_dynamic_activation_float8_weight(
-                granularity=(UnsupportedGranularity(), UnsupportedGranularity())
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=(UnsupportedGranularity(), UnsupportedGranularity()),
             )

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -187,7 +193,8 @@ def test_per_row_with_float32(self):
         ):
             model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
             quantize_(
-                model, float8_dynamic_activation_float8_weight(granularity=PerRow())
+                model,
+                Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
             )

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -199,19 +206,26 @@ def test_serialization(self, mode: str):
         # Create and quantize the model
         model = ToyLinearModel(16, 32).to(device="cuda")

+        fp8_dq_cur_version = Float8DynamicActivationFloat8WeightConfig.VERSION
+        fp8wo_cur_version = Float8WeightOnlyConfig.VERSION
+        Float8DynamicActivationFloat8WeightConfig.VERSION = 1
+        Float8WeightOnlyConfig.VERSION = 1
        mode_map = {
            "dynamic": partial(
-                float8_dynamic_activation_float8_weight, granularity=PerTensor()
+                Float8DynamicActivationFloat8WeightConfig,
+                granularity=PerTensor(),
            ),
-            "weight-only": float8_weight_only,
+            "weight-only": Float8WeightOnlyConfig,
            "static": partial(
-                float8_static_activation_float8_weight,
+                Float8StaticActivationFloat8WeightConfig,
                scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"),
                granularity=PerTensor(),
            ),
        }
+
        factory = mode_map[mode]()
        quantize_(model, factory)
+        print("model:", model)

        # Save the state dict to an in-memory buffer
        buffer = io.BytesIO()
@@ -262,6 +276,10 @@ def test_serialization(self, mode: str):
            original_layer.weight.scale, new_layer.weight.scale
        ), f"Scales do not match for {layer_name}"

+        # restore in the end
+        Float8DynamicActivationFloat8WeightConfig.VERSION = fp8_dq_cur_version
+        Float8WeightOnlyConfig.VERSION = fp8wo_cur_version
+
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(
        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
@@ -274,9 +292,13 @@ def test_fp8_weight_dimension_warning(self):
        with self.assertLogs(
            "torchao.quantization.quant_api", level="INFO"
        ) as log_context:
+            fp8wo_cur_version = Float8WeightOnlyConfig.VERSION
+            Float8DynamicActivationFloat8WeightConfig.VERSION = 1
            quantize_(
-                model, float8_dynamic_activation_float8_weight(granularity=PerTensor())
+                model,
+                Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()),
            )
+            Float8DynamicActivationFloat8WeightConfig.VERSION = fp8wo_cur_version
        print(model)

        # Verify warning messages for both layers
@@ -319,9 +341,13 @@ def test_mm_float8dq_per_row(
            torch.nn.Linear(in_features, out_features, bias=bias).to(device).to(dtype)
        )
        test_linear = copy.deepcopy(ref_linear)
+        fp8_dq_cur_version = Float8DynamicActivationFloat8WeightConfig.VERSION
+        Float8DynamicActivationFloat8WeightConfig.VERSION = 1
        quantize_(
-            test_linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+            test_linear,
+            Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
        )
+        Float8DynamicActivationFloat8WeightConfig.VERSION = fp8_dq_cur_version

        quant_weight = test_linear.weight

@@ -471,9 +497,13 @@ def test_float8_tensor_slicing_basic(self, granularity):

        # Create and quantize a model
        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+        fp8_dq_cur_version = Float8DynamicActivationFloat8WeightConfig.VERSION
+        Float8DynamicActivationFloat8WeightConfig.VERSION = 1
        quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+            model,
+            Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
        )
+        Float8DynamicActivationFloat8WeightConfig.VERSION = fp8_dq_cur_version

        weight_impl = model.weight.original_weight_tensor.tensor_impl

@@ -505,9 +535,13 @@ def test_float8_tensor_slicing_per_tensor(self):

        # Create and quantize with per-tensor granularity
        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+        fp8_dq_cur_version = Float8DynamicActivationFloat8WeightConfig.VERSION
+        Float8DynamicActivationFloat8WeightConfig.VERSION = 1
        quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
+            model,
+            Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()),
        )
+        Float8DynamicActivationFloat8WeightConfig.VERSION = fp8_dq_cur_version

        original_weight = model.weight
        original_impl = original_weight.original_weight_tensor.tensor_impl
@@ -536,9 +570,13 @@ def test_float8_tensor_slicing_per_row(self):

        # Create and quantize with per-row granularity
        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+        fp8_dq_cur_version = Float8DynamicActivationFloat8WeightConfig.VERSION
+        Float8DynamicActivationFloat8WeightConfig.VERSION = 1
        quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+            model,
+            Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
        )
+        Float8DynamicActivationFloat8WeightConfig.VERSION = fp8_dq_cur_version

        original_weight = model.weight  # Shape: (32, 64)
        original_impl = original_weight.original_weight_tensor.tensor_impl
@@ -574,9 +612,13 @@ def test_float8_tensor_slicing_edge_cases(self):

        # Create and quantize a model
        model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
+        fp8_dq_cur_version = Float8DynamicActivationFloat8WeightConfig.VERSION
+        Float8DynamicActivationFloat8WeightConfig.VERSION = 1
        quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
+            model,
+            Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()),
        )
+        Float8DynamicActivationFloat8WeightConfig.VERSION = fp8_dq_cur_version

        original_weight = model.weight

@@ -611,10 +653,13 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
            torch.nn.Linear(64, 48, bias=False).to(device).to(dtype)
        )  # 48 is divisible by 16
        quant_model = copy.deepcopy(ref_model)
+        fp8_dq_cur_version = Float8DynamicActivationFloat8WeightConfig.VERSION
+        Float8DynamicActivationFloat8WeightConfig.VERSION = 1
        quantize_(
            quant_model,
            Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
        )
+        Float8DynamicActivationFloat8WeightConfig.VERSION = fp8_dq_cur_version

        # Create input with batch size that works well with slicing
        input_tensor = torch.randn(8, 64, device=device, dtype=dtype)
@@ -720,6 +765,7 @@ def test_preprocess_scale_3d_reshape(self):
        self.assertEqual(result.shape, expected_shape)

    @torch.no_grad()
+    @unittest.skip("test is flaky in CI, will turn on a bit later")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(
        not is_sm_at_least_90(), "Requires GPU with compute capability >= 9.0"
@@ -743,7 +789,14 @@ def test_expected_kernels_on_gpu(self, granularity, torch_compile_mode):
        m = torch.nn.Sequential(
            torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
        )
-        quantize_(m, Float8DynamicActivationFloat8WeightConfig(granularity=granularity))
+        fp8_dq_cur_version = Float8DynamicActivationFloat8WeightConfig.VERSION
+        Float8DynamicActivationFloat8WeightConfig.VERSION = 1
+        quantize_(
+            m,
+            Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
+        )
+        Float8DynamicActivationFloat8WeightConfig.VERSION = fp8_dq_cur_version
+
        m = torch.compile(m, mode=torch_compile_mode)
        x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
```
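The save/set/restore pattern around `VERSION` repeats in nearly every test above, and as written it skips the restore if the body raises. A small context manager would factor it out and make it exception-safe; `_config_version` below is a hypothetical helper sketched for illustration, not part of this PR:

```
from contextlib import contextmanager


@contextmanager
def _config_version(config_cls, version: int):
    """Temporarily pin config_cls.VERSION; restore it even if the body raises."""
    saved = config_cls.VERSION
    config_cls.VERSION = version
    try:
        yield
    finally:
        config_cls.VERSION = saved


# Example use inside a test:
# with _config_version(Float8DynamicActivationFloat8WeightConfig, 1):
#     quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
```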

test/float8/test_base.py

Lines changed: 3 additions & 3 deletions

```
@@ -473,10 +473,10 @@ def test_quantize(self):
         m = nn.Sequential(nn.Linear(32, 32)).cuda()
         m = convert_to_float8_training(m)
         assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
-        from torchao.quantization.quant_api import float8_weight_only, quantize_
+        from torchao.quantization import Float8WeightOnlyConfig, quantize_

-        quantize_(m, float8_weight_only())
-        assert m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn, (
+        quantize_(m, Float8WeightOnlyConfig())
+        assert m[0].weight.qdata.dtype == torch.float8_e4m3fn, (
             "Post quantization dtype should be torch.float8_e4m3fn"
         )
         with torch.no_grad():
```
test/integration/test_loading_deprecated_checkpoint.py

Lines changed: 65 additions & 0 deletions

```
@@ -0,0 +1,65 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+import unittest
+import warnings
+
+import torch
+from torch.testing._internal import common_utils
+from torch.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+)
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from torchao.utils import is_sm_at_least_89
+
+_MODEL_NAMES = [
+    "torchao-testing/opt-125m-float8dq-row-v1-0.13-dev",
+]
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
+class TestLoadingDeprecatedCheckpoint(TestCase):
+    @common_utils.parametrize("model_name", _MODEL_NAMES)
+    def test_load_model_and_run(self, model_name):
+        """Test that we print the correct warning message when loading a deprecated checkpoint"""
+        # Load and quantize model
+        with warnings.catch_warnings(record=True) as caught_warnings:
+            quantized_model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype="bfloat16",
+                device_map="cuda",
+            )
+        assert any(
+            "Stored version is not the same as current default version of the config"
+            in str(w.message)
+            for w in caught_warnings
+        ), "Didn't get expected warning message for version mismatch"
+
+        assert any(
+            "Models quantized with VERSION 1 of Float8DynamicActivationFloat8WeightConfig is deprecated"
+            in str(w.message)
+            for w in caught_warnings
+        ), "Didn't get expected warning message for deprecation"
+
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        prompt = ("Hello, my name is",)
+        inputs = tokenizer(
+            prompt,
+            return_tensors="pt",
+        ).to("cuda")
+        generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
+        # make sure it runs
+        _ = tokenizer.batch_decode(
+            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+
+
+common_utils.instantiate_parametrized_tests(TestLoadingDeprecatedCheckpoint)
+
+if __name__ == "__main__":
+    run_tests()
```
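As a side note, the warning assertions in this file (and in test_config.py above) could also be written with `pytest.warns`, which fails the test automatically when no matching warning fires. A sketch, not part of this PR:

```
import pytest
from transformers import AutoModelForCausalLM


def test_load_deprecated_checkpoint_warns():
    # pytest.warns fails the test if no matching UserWarning is raised.
    with pytest.warns(
        UserWarning,
        match="Stored version is not the same as current default version",
    ):
        AutoModelForCausalLM.from_pretrained(
            "torchao-testing/opt-125m-float8dq-row-v1-0.13-dev",
            torch_dtype="bfloat16",
            device_map="cuda",
        )
```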
