Skip to content

Commit 5fe1ad4

Browse files
pianpwk authored and pobin6 committed
[draft export] generate fake outputs when real tensor prop finds mismatches (pytorch#139766)
Currently real tensor tracing raises MetadataMismatchErrors if registered fake kernels don't match the real kernels (e.g. shape, aliasing, dtype, etc.). This adds an option to use fake kernel inference to bypass mismatches - this option defaults to False for real tensor tracing, but is on for draft export. Pull Request resolved: pytorch#139766 Approved by: https://github.com/angelayi, https://github.com/zou3519
1 parent b773831 commit 5fe1ad4

File tree

5 files changed

+394
-165
lines changed

5 files changed

+394
-165
lines changed

test/export/test_draft_export.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
# Owner(s): ["oncall: export"]
22
import copy
3+
from typing import List, Tuple
34

45
import torch
5-
from torch.export import Dim
6+
from torch.export import Dim, export
67
from torch.export._draft_export import draft_export, FailureType
78
from torch.testing import FileCheck
89
from torch.testing._internal.common_utils import run_tests, TestCase
910
from torch.testing._internal.torchbind_impls import (
1011
_empty_tensor_queue,
1112
init_torchbind_implementations,
1213
)
14+
from torch.utils._pytree import tree_leaves
1315

1416

1517
class TestDraftExport(TestCase):
@@ -271,6 +273,89 @@ def forward(self, tq, x):
271273
self.assertEqual(tq3.size(), 2)
272274
self.assertEqual(tq.size(), 2)
273275

276+
def test_override_size_and_dtype_mismatched_fake_kernels(self):
277+
class M(torch.nn.Module):
278+
def forward(self, a):
279+
return torch.ops.mylib.foo(a)
280+
281+
@torch.library.custom_op("mylib::foo", mutates_args={})
282+
def foo(a: torch.Tensor) -> List[torch.Tensor]:
283+
x = a * 2
284+
y = a.repeat(2, 2)
285+
z = a.to(torch.bfloat16)
286+
return [x, y, z]
287+
288+
@foo.register_fake
289+
def foo_fake_impl(a):
290+
x = torch.empty_like(a) # good
291+
y = torch.empty_like(a) # size mismatch
292+
z = torch.empty_like(a) # dtype mismatch
293+
return [x, y, z]
294+
295+
mod = M()
296+
inputs = (torch.randn(3, 3),)
297+
with self.assertRaises(RuntimeError):
298+
with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
299+
export(mod, inputs)
300+
301+
ep, report = draft_export(mod, inputs)
302+
for ep_out, eager_out in zip(ep.module()(*inputs), mod(*inputs)):
303+
self.assertTrue(torch.allclose(ep_out, eager_out))
304+
self.assertEqual(ep_out.dtype, eager_out.dtype)
305+
306+
self.assertEqual(len(report.failures), 2)
307+
self.assertEqual(
308+
report.failures[0].failure_type, FailureType.MISMATCHED_FAKE_KERNEL
309+
)
310+
self.assertEqual(
311+
report.failures[1].failure_type, FailureType.MISMATCHED_FAKE_KERNEL
312+
)
313+
self.assertEqual(
314+
sorted([f.data["reason"] for f in report.failures]),
315+
[
316+
"Dtypes torch.bfloat16 and torch.float32 are not equal!",
317+
"mismatch between fake value 3 and real value 6 ",
318+
],
319+
)
320+
321+
def test_override_incorrectly_aliasing_kernel(self):
322+
class M(torch.nn.Module):
323+
def forward(self, a):
324+
return torch.ops.mylib.foo(a)
325+
326+
@torch.library.custom_op("mylib::foo", mutates_args={})
327+
def foo(a: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
328+
return a * 2, a + 2
329+
330+
@foo.register_fake
331+
def foo_fake_impl(a):
332+
return a, torch.empty_like(a) # incorrectly aliasing
333+
334+
mod = M()
335+
inputs = (torch.randn(3, 3),)
336+
with self.assertRaisesRegex(
337+
RuntimeError,
338+
"Real tensor propagation found an aliasing mismatch",
339+
):
340+
with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
341+
export(mod, inputs)
342+
343+
ep, report = draft_export(mod, inputs)
344+
for ep_out, eager_out in zip(
345+
tree_leaves(ep.module()(*inputs)), tree_leaves(mod(*inputs))
346+
):
347+
self.assertTrue(torch.allclose(ep_out, eager_out))
348+
self.assertEqual(ep_out.dtype, eager_out.dtype)
349+
350+
self.assertEqual(len(report.failures), 1)
351+
self.assertEqual(
352+
report.failures[0].failure_type, FailureType.MISMATCHED_FAKE_KERNEL
353+
)
354+
self.assertTrue(
355+
"Mismatched aliasing spec between fake kernel and real kernel"
356+
in report.failures[0].data["reason"]
357+
)
358+
274359

275360
if __name__ == "__main__":
276361
run_tests()

test/export/test_export.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,8 +1113,8 @@ def foo_fake_impl(a, b):
11131113
# catch concrete inequality
11141114
with self.assertRaisesRegex(
11151115
error_type,
1116-
"Real tensor propagation found an output size mismatch between fake shape 8 and real shape 4, "
1117-
"at output index 0, dimension 0 for func: mylib.foo.default",
1116+
r"Real tensor propagation found an output size mismatch between fake shape 8 and real shape 4, "
1117+
r"at output\.size\(0\), for func: mylib.foo.default",
11181118
):
11191119
export(
11201120
M(),
@@ -1133,8 +1133,8 @@ def foo_fake_impl(a, b):
11331133
)
11341134
with self.assertRaisesRegex(
11351135
error_type,
1136-
"Real tensor propagation found an output size mismatch between fake shape s1 and real shape 4, "
1137-
"at output index 0, dimension 0 for func: mylib.foo.default",
1136+
r"Real tensor propagation found an output size mismatch between fake shape s1 and real shape 4, "
1137+
r"at output\.size\(0\), for func: mylib.foo.default",
11381138
):
11391139
export(
11401140
M(),
@@ -1193,7 +1193,7 @@ def foo_fake_impl(a):
11931193
with self.assertRaisesRegex(
11941194
error_type,
11951195
r"Real tensor propagation found a metadata mismatch between fake tensor (.*\n)*.* "
1196-
r"and real tensor (.*\n)*.* at output index 0, for func: mylib.foo_dtype.default",
1196+
r"and real tensor (.*\n)*.* at output, for func: mylib.foo_dtype.default",
11971197
):
11981198
ep = export(N(), (torch.randn(4, 4),))
11991199

@@ -1415,6 +1415,33 @@ def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
14151415
with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
14161416
ep = export(model, inputs)
14171417

1418+
def test_real_tensor_errors_on_aliasing_custom_op(self):
1419+
@torch.library.custom_op("export::foo_alias", mutates_args={})
1420+
def foo(x: torch.Tensor) -> torch.Tensor:
1421+
return x
1422+
1423+
class Foo(torch.nn.Module):
1424+
def forward(self, x):
1425+
return torch.ops.export.foo_alias(x) * 2
1426+
1427+
model = Foo()
1428+
inputs = (torch.randn(4, 4),)
1429+
error_type = (
1430+
RuntimeError
1431+
if is_non_strict_test(self._testMethodName)
1432+
else torch._dynamo.exc.TorchRuntimeError
1433+
)
1434+
with self.assertRaisesRegex(
1435+
error_type,
1436+
(
1437+
r"The output of this custom operator \(1\) must not also be an input "
1438+
r"to this custom operator and \(2\) may not alias any inputs to this "
1439+
r"custom operator or other returns"
1440+
),
1441+
):
1442+
with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
1443+
ep = export(model, inputs)
1444+
14181445
@testing.expectedFailureSerDer # SymBool serialization? TODO(pianpwk)
14191446
@testing.expectedFailureSerDerNonStrict
14201447
def test_real_tensor_bool_cast(self):

torch/_functorch/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,11 @@ def remote_autograd_cache_default() -> Optional[bool]:
205205
# Supported formats are defined here https://graphviz.org/docs/outputs/
206206
torch_compile_graph_format = os.environ.get("TORCH_COMPILE_GRAPH_FORMAT", "svg")
207207

208+
# Valid only if fake_tensor_propagate_real_tensors = True; if a fake-real
209+
# kernel mismatch is detected, bypasses by making a fake kernel from the
210+
# real tensor outputs.
211+
generate_fake_kernels_from_real_mismatches = False
212+
208213

209214
# Error on BypassAOTAutogradCache instead of just a warning
210215
# Used for tests

0 commit comments

Comments (0)