
Replace export_for_training with torch.export.export #2724

Merged · 28 commits · Aug 13, 2025

Commits (28)
e1d7de3 Deprecate old TORCH_VERSION variables (andrewor14, Aug 8, 2025)
922fc3e Update on "Deprecate old TORCH_VERSION variables" (andrewor14, Aug 8, 2025)
fc7dffe Drop support for PyTorch 2.5 and before (andrewor14, Aug 8, 2025)
42f9081 Remove old `change_linear_weights_to_*` APIs (andrewor14, Aug 8, 2025)
83fb739 Update base for Update on "Remove old `change_linear_weights_to_*` APIs" (andrewor14, Aug 8, 2025)
afedb9f Update on "Remove old `change_linear_weights_to_*` APIs" (andrewor14, Aug 8, 2025)
4697b22 Update base for Update on "Remove old `change_linear_weights_to_*` APIs" (andrewor14, Aug 8, 2025)
da64318 Update on "Remove old `change_linear_weights_to_*` APIs" (andrewor14, Aug 8, 2025)
ac6c78f Update base for Update on "Remove old `change_linear_weights_to_*` APIs" (andrewor14, Aug 8, 2025)
4d93ac7 Update on "Remove old `change_linear_weights_to_*` APIs" (andrewor14, Aug 8, 2025)
d6c4715 Update base for Update on "Remove old `change_linear_weights_to_*` APIs" (andrewor14, Aug 8, 2025)
d4762be Update on "Remove old `change_linear_weights_to_*` APIs" (andrewor14, Aug 8, 2025)
2bc14bd Replace `export_for_training` with `torch.export.export` (andrewor14, Aug 8, 2025)
e87d7b2 Update base for Update on "Replace `export_for_training` with `torch.export.export`" (andrewor14, Aug 8, 2025)
8a8843f Update on "Replace `export_for_training` with `torch.export.export`" (andrewor14, Aug 8, 2025)
827d81b Update base for Update on "Replace `export_for_training` with `torch.export.export`" (andrewor14, Aug 12, 2025)
10066ca Update on "Replace `export_for_training` with `torch.export.export`" (andrewor14, Aug 12, 2025)
a15f8fa Update base for Update on "Replace `export_for_training` with `torch.export.export`" (andrewor14, Aug 12, 2025)
b3cf7b4 Update on "Replace `export_for_training` with `torch.export.export`" (andrewor14, Aug 12, 2025)
a5f8040 Update base for Update on "Replace `export_for_training` with `torch.export.export`" (andrewor14, Aug 12, 2025)
7792da8 Update on "Replace `export_for_training` with `torch.export.export`" (andrewor14, Aug 12, 2025)
8caf0a5 Update base for Update on "Replace `export_for_training` with `torch.export.export`" (andrewor14, Aug 12, 2025)
c2ffa16 Update on "Replace `export_for_training` with `torch.export.export`" (andrewor14, Aug 12, 2025)
eda2df2 Update base for Update on "Replace `export_for_training` with `torch.export.export`" (andrewor14, Aug 12, 2025)
1f0de23 Update on "Replace `export_for_training` with `torch.export.export`" (andrewor14, Aug 12, 2025)
69c6c34 Update base for Update on "Replace `export_for_training` with `torch.export.export`" (andrewor14, Aug 13, 2025)
eed1c7c Update on "Replace `export_for_training` with `torch.export.export`" (andrewor14, Aug 13, 2025)
d687491 Merge branch 'main' into gh/andrewor14/22/head (andrewor14, Aug 13, 2025)
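Files changed

Every hunk below applies the same mechanical migration: the deprecated `torch.export.export_for_training` (and the older `torch._export.capture_pre_autograd_graph`) becomes `torch.export.export`, with `.module()` unwrapping the resulting `ExportedProgram` wherever a `GraphModule` is needed. A minimal sketch of the pattern (the toy module and shapes here are illustrative, not from this PR):

```python
import torch
import torch.nn as nn

class Toy(nn.Module):  # illustrative stand-in for the models touched in this PR
    def forward(self, x):
        return torch.relu(x)

m = Toy().eval()
example_inputs = (torch.randn(2, 8),)

# Before (deprecated):
#   exported = torch.export.export_for_training(m, example_inputs).module()
# After:
exported = torch.export.export(m, example_inputs).module()
```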
4 changes: 2 additions & 2 deletions docs/source/tutorials_source/pt2e_quant_ptq.rst
@@ -362,7 +362,7 @@ Here is how you can use ``torch.export`` to export the model:
         {0: torch.export.Dim("dim")} if i == 0 else None
         for i in range(len(example_inputs))
     )
-    exported_model = torch.export.export_for_training(model_to_quantize, example_inputs, dynamic_shapes=dynamic_shapes).module()
+    exported_model = torch.export.export(model_to_quantize, example_inputs, dynamic_shapes=dynamic_shapes).module()
     # for pytorch 2.5 and before
     # dynamic_shape API may vary as well

@@ -501,7 +501,7 @@ Now we can compare the size and model accuracy with baseline model.
     # Quantized model size and accuracy
     print("Size of model after quantization")
     # export again to remove unused weights
-    quantized_model = torch.export.export_for_training(quantized_model, example_inputs).module()
+    quantized_model = torch.export.export(quantized_model, example_inputs).module()
     print_size_of_model(quantized_model)
     top1, top5 = evaluate(quantized_model, criterion, data_loader_test)
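Since the tutorial marks the first input's batch dimension dynamic, the migrated call keeps the same `dynamic_shapes` argument. A self-contained sketch of that pattern (the toy ConvNet and shapes are illustrative, not the tutorial's model):

```python
import torch
import torch.nn as nn

class ToyConvNet(nn.Module):  # illustrative stand-in for model_to_quantize
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)

    def forward(self, x):
        return self.conv(x)

example_inputs = (torch.randn(4, 3, 32, 32),)
# Mark dim 0 of the first positional input as dynamic, as the tutorial does.
dynamic_shapes = tuple(
    {0: torch.export.Dim("dim")} if i == 0 else None
    for i in range(len(example_inputs))
)
exported_model = torch.export.export(
    ToyConvNet().eval(), example_inputs, dynamic_shapes=dynamic_shapes
).module()
# The captured GraphModule now accepts other batch sizes.
out = exported_model(torch.randn(7, 3, 32, 32))
```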
4 changes: 1 addition & 3 deletions docs/source/tutorials_source/pt2e_quant_qat.rst
@@ -13,7 +13,6 @@ to the post training quantization (PTQ) flow for the most part:

 .. code:: python

     import torch
-    from torch._export import capture_pre_autograd_graph
     from torchao.quantization.pt2e.quantize_pt2e import (
         prepare_qat_pt2e,
         convert_pt2e,

@@ -434,7 +433,6 @@ prepared. For example:

 .. code:: python

-    from torch._export import capture_pre_autograd_graph
     from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
         get_symmetric_quantization_config,
         XNNPACKQuantizer,

@@ -443,7 +441,7 @@

     example_inputs = (torch.rand(2, 3, 224, 224),)
     float_model = resnet18(pretrained=False)
-    exported_model = capture_pre_autograd_graph(float_model, example_inputs)
+    exported_model = torch.export.export(float_model, example_inputs).module()
     quantizer = XNNPACKQuantizer()
     quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
     prepared_model = prepare_qat_pt2e(exported_model, quantizer)
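Assembled from the hunks above, the QAT capture flow after this change reads as follows (executorch's XNNPACKQuantizer and torchvision's resnet18 are the tutorial's own choices; the training loop is elided):

```python
import torch
from torchvision.models import resnet18
from torchao.quantization.pt2e.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

example_inputs = (torch.rand(2, 3, 224, 224),)
float_model = resnet18(pretrained=False)

# Capture an ATen-level graph with torch.export, then attach fake-quant observers.
exported_model = torch.export.export(float_model, example_inputs).module()
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
prepared_model = prepare_qat_pt2e(exported_model, quantizer)

# ... run the usual training loop on prepared_model ...

quantized_model = convert_pt2e(prepared_model)
```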
8 changes: 2 additions & 6 deletions docs/source/tutorials_source/pt2e_quant_x86_inductor.rst
@@ -105,7 +105,7 @@ We will start by performing the necessary imports, capturing the FX Graph from t
     exported_model = export(
         model,
         example_inputs
-    )
+    ).module()


 Next, we will have the FX Module to be quantized.

@@ -243,12 +243,10 @@ The PyTorch 2 Export QAT flow is largely similar to the PTQ flow:

 .. code:: python

     import torch
-    from torch._export import capture_pre_autograd_graph
     from torchao.quantization.pt2e.quantize_pt2e import (
         prepare_qat_pt2e,
         convert_pt2e,
     )
-    from torch.export import export
     import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
     from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import X86InductorQuantizer

@@ -264,9 +262,7 @@ The PyTorch 2 Export QAT flow is largely similar to the PTQ flow:
     m = M()

     # Step 1. program capture
-    # NOTE: this API will be updated to torch.export API in the future, but the captured
-    # result shoud mostly stay the same
-    exported_model = export(m, example_inputs)
+    exported_model = torch.export.export(m, example_inputs).module()
     # we get a model with aten ops

     # Step 2. quantization-aware training
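The x86 variant follows the same two steps with a different quantizer. A hedged sketch, assuming torchao's port keeps the `get_default_x86_inductor_quantization_config` helper that the upstream x86_inductor_quantizer exposes (the toy module is illustrative):

```python
import torch
import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
from torchao.quantization.pt2e.quantize_pt2e import prepare_qat_pt2e
from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import X86InductorQuantizer

class M(torch.nn.Module):  # stand-in for the tutorial's toy model
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)

example_inputs = (torch.randn(2, 16),)
m = M()

# Step 1. program capture, now via torch.export.export
exported_model = torch.export.export(m, example_inputs).module()

# Step 2. quantization-aware training
# Assumption: the torchao port mirrors upstream's config helper.
quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config(is_qat=True)
)
prepared_model = prepare_qat_pt2e(exported_model, quantizer)
```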
5 changes: 1 addition & 4 deletions examples/sam2_amg_server/compile_export_utils.py
@@ -118,10 +118,7 @@ def aot_compile(
         "max_autotune": True,
         "triton.cudagraphs": True,
     }
-
-    from torch.export import export_for_training
-
-    exported = export_for_training(fn, sample_args, sample_kwargs, strict=True)
+    exported = torch.export.export(fn, sample_args, sample_kwargs, strict=True)
     exported.run_decompositions()
     output_path = torch._inductor.aoti_compile_and_package(
         exported,
5 changes: 1 addition & 4 deletions examples/sam2_vos_example/compile_export_utils.py
@@ -81,10 +81,7 @@ def aot_compile(
         "max_autotune": True,
         "triton.cudagraphs": True,
     }
-
-    from torch.export import export_for_training
-
-    exported = export_for_training(fn, sample_args, sample_kwargs, strict=True)
+    exported = torch.export.export(fn, sample_args, sample_kwargs, strict=True)
     exported.run_decompositions()
     output_path = torch._inductor.aoti_compile_and_package(
         exported,
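Both SAM2 examples share the same `aot_compile` shape: export, decompose, then compile-and-package with AOTInductor. A condensed sketch of that flow after this change (the function name is mine; note also that `run_decompositions()` returns a new `ExportedProgram` rather than mutating in place, so this sketch captures the result, unlike the example code above):

```python
import torch

def aot_compile_sketch(fn, sample_args, sample_kwargs, output_path):
    # inductor_configs mirror the values in the surrounding examples.
    inductor_configs = {
        "max_autotune": True,
        "triton.cudagraphs": True,
    }
    exported = torch.export.export(fn, sample_args, sample_kwargs, strict=True)
    # Capture the decomposed program; run_decompositions() does not mutate.
    exported = exported.run_decompositions()
    return torch._inductor.aoti_compile_and_package(
        exported,
        package_path=output_path,
        inductor_configs=inductor_configs,
    )
```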
5 changes: 1 addition & 4 deletions test/dtypes/test_uint4.py
@@ -242,10 +242,7 @@ def forward(self, x):

     # program capture
     m = copy.deepcopy(m_eager)
-    m = torch.export.texport_for_training(
-        m,
-        example_inputs,
-    ).module()
+    m = torch.export.export(m, example_inputs).module()

     m = prepare_pt2e(m, quantizer)
     # Calibrate
4 changes: 1 addition & 3 deletions test/integration/test_integration.py
@@ -1953,9 +1953,7 @@ def forward(self, x):
         # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu
         # we can re-enable this after non-functional IR is enabled in export
         # model = torch.export.export(model, example_inputs).module()
-        model = torch.export.export_for_training(
-            model, example_inputs, strict=True
-        ).module()
+        model = torch.export.export(model, example_inputs, strict=True).module()
         after_export = model(x)
         self.assertTrue(torch.equal(after_export, ref))
         if api is _int8da_int4w_api:
8 changes: 1 addition & 7 deletions test/prototype/inductor/test_int8_sdpa_fusion.py
@@ -157,8 +157,6 @@ def _check_common(
     )
     @config.patch({"freezing": True})
     def _test_sdpa_int8_rewriter(self):
-        from torch.export import export_for_training
-
         import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
         from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
         from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (

@@ -199,11 +197,7 @@ def _test_sdpa_int8_rewriter(self):
             quantizer.set_function_type_qconfig(
                 torch.matmul, quantizer.get_global_quantization_config()
             )
-            export_model = export_for_training(
-                mod,
-                inputs,
-                strict=True,
-            ).module()
+            export_model = torch.export.export(mod, inputs, strict=True).module()
             prepare_model = prepare_pt2e(export_model, quantizer)
             prepare_model(*inputs)
             convert_model = convert_pt2e(prepare_model)
12 changes: 4 additions & 8 deletions test/quantization/pt2e/test_arm_inductor_quantizer.py
@@ -14,7 +14,6 @@

 import torch
 import torch.nn as nn
-from torch.export import export_for_training
 from torch.testing._internal.common_quantization import (
     NodeSpec as ns,
 )

@@ -315,10 +314,7 @@ def _test_quantizer(

         # program capture
         m = copy.deepcopy(m_eager)
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = torch.export.export(m, example_inputs).module()

         # QAT Model failed to deepcopy
         export_model = m if is_qat else copy.deepcopy(m)

@@ -576,7 +572,7 @@ def _test_linear_unary_helper(
         Test pattern of linear with unary post ops (e.g. relu) with ArmInductorQuantizer.
         """
         use_bias_list = [True, False]
-        # TODO test for inplace add after refactoring of export_for_training
+        # TODO test for inplace add after refactoring of export
         inplace_list = [False]
         if post_op_algo_list is None:
             post_op_algo_list = [None]

@@ -716,7 +712,7 @@ def _test_linear_binary_helper(self, is_qat=False, is_dynamic=False):
         Currently, only add as binary post op is supported.
         """
         linear_pos_list = [NodePosType.left, NodePosType.right, NodePosType.both]
-        # TODO test for inplace add after refactoring of export_for_training
+        # TODO test for inplace add after refactoring of export
        inplace_add_list = [False]
         example_inputs = (torch.randn(2, 16),)
         quantizer = ArmInductorQuantizer().set_global(

@@ -1078,7 +1074,7 @@ def forward(self, x):
         )
         example_inputs = (torch.randn(2, 2),)
         m = M().eval()
-        m = export_for_training(m, example_inputs).module()
+        m = torch.export.export(m, example_inputs).module()
         m = prepare_pt2e(m, quantizer)
         # Use a linear count instead of names because the names might change, but
         # the order should be the same.
2 changes: 1 addition & 1 deletion test/quantization/pt2e/test_duplicate_dq.py
@@ -110,7 +110,7 @@ def _test_duplicate_dq(

         # program capture
         m = copy.deepcopy(m_eager)
-        m = export_for_training(m, example_inputs, strict=True).module()
+        m = torch.export.export(m, example_inputs, strict=True).module()

         m = prepare_pt2e(m, quantizer)
         # Calibrate
2 changes: 1 addition & 1 deletion test/quantization/pt2e/test_metadata_porting.py
@@ -107,7 +107,7 @@ def _test_metadata_porting(

         # program capture
         m = copy.deepcopy(m_eager)
-        m = torch.export.export_for_training(m, example_inputs, strict=True).module()
+        m = torch.export.export(m, example_inputs, strict=True).module()

         m = prepare_pt2e(m, quantizer)
         # Calibrate
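The test files above all exercise the same capture, prepare, calibrate, convert sequence. A hedged sketch of that pattern as a standalone helper (the name and signature are illustrative, not a torchao API):

```python
import copy
import torch
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

def quantize_pt2e_sketch(m_eager: torch.nn.Module, quantizer, example_inputs: tuple):
    """Capture -> prepare -> calibrate -> convert, as the tests above do."""
    m = copy.deepcopy(m_eager)
    m = torch.export.export(m, example_inputs, strict=True).module()
    m = prepare_pt2e(m, quantizer)
    m(*example_inputs)  # calibration pass with representative inputs
    return convert_pt2e(m)
```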
21 changes: 9 additions & 12 deletions test/quantization/pt2e/test_numeric_debugger.py
@@ -20,11 +20,8 @@
 from torchao.testing.pt2e.utils import PT2ENumericDebuggerTestCase
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_8

-if TORCH_VERSION_AT_LEAST_2_8:
-    from torch.export import export_for_training
-
 # Increase cache size limit to avoid FailOnRecompileLimitHit error when running multiple tests
-# that use export_for_training, which causes many dynamo recompilations
+# that use torch.export.export, which causes many dynamo recompilations
 if TORCH_VERSION_AT_LEAST_2_8:
     torch._dynamo.config.cache_size_limit = 128

@@ -37,7 +34,7 @@ class TestNumericDebuggerInfra(PT2ENumericDebuggerTestCase):
     def test_simple(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs, strict=True)
+        ep = torch.export.export(m, example_inputs, strict=True)
         m = ep.module()
         self._assert_each_node_has_from_node_source(m)
         from_node_source_map = self._extract_from_node_source(m)

@@ -50,7 +47,7 @@ def test_simple(self):
     def test_control_flow(self):
         m = TestHelperModules.ControlFlow()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs, strict=True)
+        ep = torch.export.export(m, example_inputs, strict=True)
         m = ep.module()

         self._assert_each_node_has_from_node_source(m)

@@ -93,13 +90,13 @@ def test_deepcopy_preserve_handle(self):
     def test_re_export_preserve_handle(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs, strict=True)
+        ep = torch.export.export(m, example_inputs, strict=True)
         m = ep.module()

         self._assert_each_node_has_from_node_source(m)
         from_node_source_map_ref = self._extract_from_node_source(m)

-        ep_reexport = export_for_training(m, example_inputs, strict=True)
+        ep_reexport = torch.export.export(m, example_inputs, strict=True)
         m_reexport = ep_reexport.module()

         self._assert_each_node_has_from_node_source(m_reexport)

@@ -110,7 +107,7 @@ def test_re_export_preserve_handle(self):
     def test_run_decompositions_same_handle_id(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs, strict=True)
+        ep = torch.export.export(m, example_inputs, strict=True)
         m = ep.module()

         self._assert_each_node_has_from_node_source(m)

@@ -136,7 +133,7 @@ def test_run_decompositions_map_handle_to_new_nodes(self):

         for m in test_models:
             example_inputs = m.example_inputs()
-            ep = export_for_training(m, example_inputs, strict=True)
+            ep = torch.export.export(m, example_inputs, strict=True)
             m = ep.module()

             self._assert_each_node_has_from_node_source(m)

@@ -161,7 +158,7 @@ def test_run_decompositions_map_handle_to_new_nodes(self):
     def test_prepare_for_propagation_comparison(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs, strict=True)
+        ep = torch.export.export(m, example_inputs, strict=True)
         m = ep.module()
         m_logger = prepare_for_propagation_comparison(m)
         ref = m(*example_inputs)

@@ -177,7 +174,7 @@ def test_prepare_for_propagation_comparison(self):
     def test_added_node_gets_unique_id(self) -> None:
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs, strict=True)
+        ep = torch.export.export(m, example_inputs, strict=True)

         ref_from_node_source = self._extract_from_node_source(ep.module())
         ref_counter = Counter(ref_from_node_source.values())
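One detail the numeric-debugger tests rely on: `torch.export.export` returns an `ExportedProgram`, and `.module()` unwraps it into a runnable `torch.fx.GraphModule`. A toy illustration (the module and input are placeholders, not from the tests):

```python
import torch
import torch.nn as nn

class AddOne(nn.Module):  # placeholder module
    def forward(self, x):
        return x + 1

ep = torch.export.export(AddOne(), (torch.randn(2),), strict=True)
gm = ep.module()  # a GraphModule, callable like the original module

x = torch.randn(2)
assert torch.allclose(gm(x), x + 1)
```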