
Commit 615877d

Replace export_for_training with torch.export.export (#2724)
* Deprecate old TORCH_VERSION variables

  **Summary:** This commit deprecates the following variables:

  ```
  TORCH_VERSION_AT_LEAST_2_5
  TORCH_VERSION_AT_LEAST_2_4
  TORCH_VERSION_AT_LEAST_2_3
  TORCH_VERSION_AT_LEAST_2_2
  TORCH_VERSION_AFTER_2_5
  TORCH_VERSION_AFTER_2_4
  TORCH_VERSION_AFTER_2_3
  TORCH_VERSION_AFTER_2_2
  ```

  As of this commit, the latest released version of PyTorch is 2.8. Since we only support the three latest releases, we can drop support for 2.5 and before. The next commit removes all usages of these variables from within torchao.

  **Test Plan:**

  ```
  python test/test_utils.py -k torch_version_deprecation
  ```

* Drop support for PyTorch 2.5 and before

  **Summary:** We gate on the PyTorch version throughout the repo. PyTorch 2.8 was recently released, so the oldest version we need to support is 2.6. After this commit, we assume the user is running PyTorch 2.6+ and remove all references to the following deprecated variables:

  ```
  TORCH_VERSION_AT_LEAST_2_6
  TORCH_VERSION_AT_LEAST_2_5
  TORCH_VERSION_AT_LEAST_2_4
  TORCH_VERSION_AT_LEAST_2_3
  TORCH_VERSION_AT_LEAST_2_2
  TORCH_VERSION_AFTER_2_5
  TORCH_VERSION_AFTER_2_4
  TORCH_VERSION_AFTER_2_3
  TORCH_VERSION_AFTER_2_2
  ```

  **Test Plan:** CI

* Remove old `change_linear_weights_to_*` APIs

  **Summary:** This commit removes these very old quantization APIs, which are no longer even accessible to the user:

  ```
  change_linear_weights_to_int8_dqtensors
  change_linear_weights_to_int8_woqtensors
  change_linear_weights_to_int4_woqtensors
  ```

  **Test Plan:** CI

* Replace `export_for_training` with `torch.export.export`

  **Summary:** Bypasses the following deprecation warning:

  ```
  `torch.export.export_for_training` is deprecated and will be removed in PyTorch 2.10.
  Please use `torch.export.export` instead, which is functionally equivalent.
  ```

  Bonus: remove some references to `capture_pre_autograd_graph`, which is even older.

  **Test Plan:** CI
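For readers skimming the stack, a minimal sketch of the API swap this commit performs, assuming a toy module (the model and inputs below are illustrative, not from the diff):

```python
import torch

# Hypothetical toy model, used only to illustrate the API swap.
class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

model = M().eval()
example_inputs = (torch.randn(2, 3),)

# Before (deprecated, to be removed in PyTorch 2.10):
# exported = torch.export.export_for_training(model, example_inputs, strict=True)

# After (functionally equivalent per the deprecation message):
exported = torch.export.export(model, example_inputs, strict=True)

# Both return an ExportedProgram; .module() yields the GraphModule
# that the pt2e quantization APIs operate on.
graph_module = exported.module()
```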
1 parent e79208c commit 615877d

23 files changed: +73 −104 lines

docs/source/tutorials_source/pt2e_quant_ptq.rst

Lines changed: 2 additions & 2 deletions
@@ -362,7 +362,7 @@ Here is how you can use ``torch.export`` to export the model:
         {0: torch.export.Dim("dim")} if i == 0 else None
         for i in range(len(example_inputs))
     )
-    exported_model = torch.export.export_for_training(model_to_quantize, example_inputs, dynamic_shapes=dynamic_shapes).module()
+    exported_model = torch.export.export(model_to_quantize, example_inputs, dynamic_shapes=dynamic_shapes).module()

     # for pytorch 2.5 and before
     # dynamic_shape API may vary as well
@@ -501,7 +501,7 @@ Now we can compare the size and model accuracy with baseline model.
     # Quantized model size and accuracy
     print("Size of model after quantization")
     # export again to remove unused weights
-    quantized_model = torch.export.export_for_training(quantized_model, example_inputs).module()
+    quantized_model = torch.export.export(quantized_model, example_inputs).module()
     print_size_of_model(quantized_model)

     top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
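For orientation, a condensed, hedged sketch of the PT2E PTQ flow this tutorial walks through, with the new capture call in place. The toy model, input shape, and quantizer configuration are assumptions for illustration; the import paths mirror ones used elsewhere in this commit:

```python
import torch
from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, convert_pt2e
# Quantizer import as used in the QAT tutorial touched by this same commit.
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)

model_to_quantize = ToyModel().eval()
example_inputs = (torch.randn(2, 16),)

# Step 1: program capture with the non-deprecated API
exported_model = torch.export.export(model_to_quantize, example_inputs).module()

# Step 2: insert observers
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())
prepared_model = prepare_pt2e(exported_model, quantizer)

# Step 3: calibrate on representative data, then convert
prepared_model(*example_inputs)
quantized_model = convert_pt2e(prepared_model)
```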

docs/source/tutorials_source/pt2e_quant_qat.rst

Lines changed: 1 addition & 3 deletions
@@ -13,7 +13,6 @@ to the post training quantization (PTQ) flow for the most part:
 .. code:: python

   import torch
-  from torch._export import capture_pre_autograd_graph
   from torchao.quantization.pt2e.quantize_pt2e import (
       prepare_qat_pt2e,
       convert_pt2e,
@@ -434,7 +433,6 @@ prepared. For example:

 .. code:: python

-  from torch._export import capture_pre_autograd_graph
   from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
       get_symmetric_quantization_config,
       XNNPACKQuantizer,
@@ -443,7 +441,7 @@ prepared. For example:

   example_inputs = (torch.rand(2, 3, 224, 224),)
   float_model = resnet18(pretrained=False)
-  exported_model = capture_pre_autograd_graph(float_model, example_inputs)
+  exported_model = torch.export.export(float_model, example_inputs).module()
   quantizer = XNNPACKQuantizer()
   quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
   prepared_model = prepare_qat_pt2e(exported_model, quantizer)
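A hedged sketch of how this updated QAT capture composes end to end; the model and quantizer calls come from the diff above, the training loop is a stub, and the final conversion step is assumed to follow the standard pt2e flow:

```python
import torch
from torchao.quantization.pt2e.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torchvision.models import resnet18  # as in the tutorial

example_inputs = (torch.rand(2, 3, 224, 224),)
float_model = resnet18(pretrained=False)

# Capture; .module() converts the ExportedProgram into the GraphModule
# that prepare_qat_pt2e expects.
exported_model = torch.export.export(float_model, example_inputs).module()

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
prepared_model = prepare_qat_pt2e(exported_model, quantizer)

# ... run fake-quant-aware training on prepared_model here ...

quantized_model = convert_pt2e(prepared_model)
```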

docs/source/tutorials_source/pt2e_quant_x86_inductor.rst

Lines changed: 2 additions & 6 deletions
@@ -105,7 +105,7 @@ We will start by performing the necessary imports, capturing the FX Graph from t
     exported_model = export(
         model,
         example_inputs
-    )
+    ).module()


 Next, we will have the FX Module to be quantized.
@@ -243,12 +243,10 @@ The PyTorch 2 Export QAT flow is largely similar to the PTQ flow:
 .. code:: python

   import torch
-  from torch._export import capture_pre_autograd_graph
   from torchao.quantization.pt2e.quantize_pt2e import (
       prepare_qat_pt2e,
       convert_pt2e,
   )
-  from torch.export import export
   import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
   from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import X86InductorQuantizer

@@ -264,9 +262,7 @@ The PyTorch 2 Export QAT flow is largely similar to the PTQ flow:
   m = M()

   # Step 1. program capture
-  # NOTE: this API will be updated to torch.export API in the future, but the captured
-  # result shoud mostly stay the same
-  exported_model = export(m, example_inputs)
+  exported_model = torch.export.export(m, example_inputs).module()
   # we get a model with aten ops

   # Step 2. quantization-aware training
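Putting the pieces of this tutorial together, a hedged sketch of the X86Inductor QAT path. The toy module is an assumption, and `get_default_x86_inductor_quantization_config` is assumed to keep the name of its `torch.ao` upstream in the torchao migration:

```python
import torch
import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
from torchao.quantization.pt2e.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import X86InductorQuantizer

# Hypothetical toy module standing in for the tutorial's M().
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)

example_inputs = (torch.randn(2, 16),)
m = M()

# Step 1. program capture (the replacement this commit makes)
exported_model = torch.export.export(m, example_inputs).module()

# Step 2. prepare for QAT with the x86 inductor quantizer
quantizer = X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config(is_qat=True))
prepared_model = prepare_qat_pt2e(exported_model, quantizer)

# ... train prepared_model, then convert ...
converted_model = convert_pt2e(prepared_model)
```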

examples/sam2_amg_server/compile_export_utils.py

Lines changed: 1 addition & 4 deletions
@@ -118,10 +118,7 @@ def aot_compile(
         "max_autotune": True,
         "triton.cudagraphs": True,
     }
-
-    from torch.export import export_for_training
-
-    exported = export_for_training(fn, sample_args, sample_kwargs, strict=True)
+    exported = torch.export.export(fn, sample_args, sample_kwargs, strict=True)
     exported.run_decompositions()
     output_path = torch._inductor.aoti_compile_and_package(
         exported,
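A self-contained sketch of the export-then-AOT-compile pattern these utilities follow. The toy module is an assumption; note that `run_decompositions()` returns a new `ExportedProgram`, which the sketch assigns (unlike the utility above):

```python
import torch

# Hypothetical stand-in for the SAM2 callable being compiled.
class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y

fn = Add()
sample_args = (torch.randn(4), torch.randn(4))

# Capture with the non-deprecated API; strict=True matches the old call.
exported = torch.export.export(fn, sample_args, strict=True)
exported = exported.run_decompositions()

# Package an ahead-of-time compiled artifact; returns the package path.
output_path = torch._inductor.aoti_compile_and_package(exported)
```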

examples/sam2_vos_example/compile_export_utils.py

Lines changed: 1 addition & 4 deletions
@@ -81,10 +81,7 @@ def aot_compile(
         "max_autotune": True,
         "triton.cudagraphs": True,
     }
-
-    from torch.export import export_for_training
-
-    exported = export_for_training(fn, sample_args, sample_kwargs, strict=True)
+    exported = torch.export.export(fn, sample_args, sample_kwargs, strict=True)
     exported.run_decompositions()
     output_path = torch._inductor.aoti_compile_and_package(
         exported,

test/dtypes/test_uint4.py

Lines changed: 1 addition & 4 deletions
@@ -242,10 +242,7 @@ def forward(self, x):

         # program capture
         m = copy.deepcopy(m_eager)
-        m = torch.export.texport_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = torch.export.export(m, example_inputs).module()

         m = prepare_pt2e(m, quantizer)
         # Calibrate

test/integration/test_integration.py

Lines changed: 1 addition & 3 deletions
@@ -1953,9 +1953,7 @@ def forward(self, x):
         # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu
         # we can re-enable this after non-functional IR is enabled in export
         # model = torch.export.export(model, example_inputs).module()
-        model = torch.export.export_for_training(
-            model, example_inputs, strict=True
-        ).module()
+        model = torch.export.export(model, example_inputs, strict=True).module()
         after_export = model(x)
         self.assertTrue(torch.equal(after_export, ref))
         if api is _int8da_int4w_api:
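The assertion pattern this test relies on — capture, re-run, and compare against the eager reference — in a standalone hedged form (toy model assumed; for a simple float module the captured graph should reproduce eager numerics exactly):

```python
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

model = M().eval()
x = torch.randn(2, 8)

ref = model(x)  # eager reference output
model = torch.export.export(model, (x,), strict=True).module()
after_export = model(x)

assert torch.equal(after_export, ref)  # capture alone should not change numerics
```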

test/prototype/inductor/test_int8_sdpa_fusion.py

Lines changed: 1 addition & 7 deletions
@@ -157,8 +157,6 @@ def _check_common(
     )
     @config.patch({"freezing": True})
     def _test_sdpa_int8_rewriter(self):
-        from torch.export import export_for_training
-
         import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
         from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
         from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
@@ -199,11 +197,7 @@ def _test_sdpa_int8_rewriter(self):
             quantizer.set_function_type_qconfig(
                 torch.matmul, quantizer.get_global_quantization_config()
             )
-            export_model = export_for_training(
-                mod,
-                inputs,
-                strict=True,
-            ).module()
+            export_model = torch.export.export(mod, inputs, strict=True).module()
             prepare_model = prepare_pt2e(export_model, quantizer)
             prepare_model(*inputs)
             convert_model = convert_pt2e(prepare_model)

test/quantization/pt2e/test_arm_inductor_quantizer.py

Lines changed: 4 additions & 8 deletions
@@ -14,7 +14,6 @@

 import torch
 import torch.nn as nn
-from torch.export import export_for_training
 from torch.testing._internal.common_quantization import (
     NodeSpec as ns,
 )
@@ -315,10 +314,7 @@ def _test_quantizer(

         # program capture
         m = copy.deepcopy(m_eager)
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = torch.export.export(m, example_inputs).module()

         # QAT Model failed to deepcopy
         export_model = m if is_qat else copy.deepcopy(m)
@@ -576,7 +572,7 @@ def _test_linear_unary_helper(
         Test pattern of linear with unary post ops (e.g. relu) with ArmInductorQuantizer.
         """
         use_bias_list = [True, False]
-        # TODO test for inplace add after refactoring of export_for_training
+        # TODO test for inplace add after refactoring of export
         inplace_list = [False]
         if post_op_algo_list is None:
             post_op_algo_list = [None]
@@ -716,7 +712,7 @@ def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False):
         Currently, only add as binary post op is supported.
         """
         linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both]
-        # TODO test for inplace add after refactoring of export_for_training
+        # TODO test for inplace add after refactoring of export
         inplace_add_list = [False]
         example_inputs = (torch.randn(2, 16),)
         quantizer = ArmInductorQuantizer().set_global(
@@ -1078,7 +1074,7 @@ def forward(self, x):
         )
         example_inputs = (torch.randn(2, 2),)
         m = M().eval()
-        m = export_for_training(m, example_inputs).module()
+        m = torch.export.export(m, example_inputs).module()
         m = prepare_pt2e(m, quantizer)
         # Use a linear count instead of names because the names might change, but
         # the order should be the same.

test/quantization/pt2e/test_duplicate_dq.py

Lines changed: 1 addition & 1 deletion
@@ -110,7 +110,7 @@ def _test_duplicate_dq(

         # program capture
         m = copy.deepcopy(m_eager)
-        m = export_for_training(m, example_inputs, strict=True).module()
+        m = torch.export.export(m, example_inputs, strict=True).module()

         m = prepare_pt2e(m, quantizer)
         # Calibrate
